From b84c533dad4db495a92fc6d390a7db5ebd938a88 Mon Sep 17 00:00:00 2001
From: Kentaro Kawakami <kawakami.k@fujitsu.com>
Date: Tue, 1 Nov 2022 09:33:41 +0900
Subject: [PATCH] cpu: aarch64: reorder: support jit-ed blk_reorder

---
 src/cpu/aarch64/jit_generator.hpp             |   20 +
 src/cpu/aarch64/jit_uni_reorder.cpp           | 2315 +++++++++++++----
 src/cpu/aarch64/jit_uni_reorder.hpp           |  183 +-
 src/cpu/aarch64/jit_uni_reorder_utils.cpp     |  482 ++--
 .../reorder/cpu_reorder_regular_f32_f32.cpp   |    6 +
 .../reorder/cpu_reorder_regular_f32_s32.cpp   |    2 +
 .../reorder/cpu_reorder_regular_f32_s8.cpp    |    2 +
 .../reorder/cpu_reorder_regular_f32_u8.cpp    |    2 +
 src/cpu/reorder/cpu_reorder_regular_s32.cpp   |    2 +
 src/cpu/reorder/cpu_reorder_regular_s8.cpp    |    2 +
 src/cpu/reorder/cpu_reorder_regular_u8.cpp    |    2 +
 11 files changed, 2272 insertions(+), 746 deletions(-)

diff --git a/src/cpu/aarch64/jit_generator.hpp b/src/cpu/aarch64/jit_generator.hpp
index dd781a622e1..12de9fa8c01 100644
--- a/src/cpu/aarch64/jit_generator.hpp
+++ b/src/cpu/aarch64/jit_generator.hpp
@@ -435,6 +435,26 @@ class jit_generator : public Xbyak_aarch64::CodeGenerator, public c_compatible {
                 Xbyak_aarch64::ZRegD(z3.getIdx()));
     }
 
+    void uni_ld1rw(const Xbyak_aarch64::VReg4S &dst,
+            const Xbyak_aarch64::XReg &base, const int64_t off) {
+        // Load the 32-bit word at [base + off] and replicate it to all four
+        // lanes of dst. A zero offset addresses through `base` directly;
+        // otherwise the effective address is first formed in X_DEFAULT_ADDR.
+        const bool direct = (off == 0);
+        if (!direct) add_imm(X_DEFAULT_ADDR, base, off, X_TMP_0);
+        ld1r(dst, ptr(direct ? base : X_DEFAULT_ADDR));
+    }
+
+    void uni_ld1rw(const Xbyak_aarch64::ZRegS &dst,
+            const Xbyak_aarch64::XReg &base, const int64_t off) {
+        if (0 <= off && off <= 252 && off % 4 == 0) {
+            ld1rw(dst, P_ALL_ONE / Xbyak_aarch64::T_z, ptr(base, (int)off));
+        } else { // LD1RW imm is unsigned: only multiples of 4 in [0, 252]
+            add_imm(X_DEFAULT_ADDR, base, off, X_TMP_0);
+            ld1rw(dst, P_ALL_ONE / Xbyak_aarch64::T_z, ptr(X_DEFAULT_ADDR));
+        }
+    }
+
     void uni_ldr(
             const Xbyak_aarch64::VReg &dst, const Xbyak_aarch64::XReg &addr) {
         ldr(Xbyak_aarch64::QReg(dst.getIdx()), ptr(addr));
diff --git a/src/cpu/aarch64/jit_uni_reorder.cpp b/src/cpu/aarch64/jit_uni_reorder.cpp
index a6cefaa20e8..a708da808c0 100644
--- a/src/cpu/aarch64/jit_uni_reorder.cpp
+++ b/src/cpu/aarch64/jit_uni_reorder.cpp
@@ -1,6 +1,6 @@
 /*******************************************************************************
-* Copyright 2018-2021 Intel Corporation
-* Copyright 2020-2021 FUJITSU LIMITED
+* Copyright 2018-2022 Intel Corporation
+* Copyright 2020-2022 FUJITSU LIMITED
 * Copyright 2022 Arm Ltd. and affiliates
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
@@ -19,19 +19,21 @@
 #include <assert.h>
 #include <numeric>
 
-#include "dnnl_debug.h"
+#include "oneapi/dnnl/dnnl_debug.h"
 
 #include "common/c_types_map.hpp"
+#include "common/dnnl_thread.hpp"
 #include "common/memory_desc_wrapper.hpp"
 #include "common/nstl.hpp"
 #include "common/primitive.hpp"
 #include "common/type_helpers.hpp"
 #include "common/utils.hpp"
 
-#include "cpu/aarch64/jit_uni_reorder.hpp"
 #include "cpu/cpu_primitive.hpp"
 #include "cpu/reorder/cpu_reorder_pd.hpp"
 
+#include "cpu/aarch64/jit_uni_reorder.hpp"
+
 #include "cpu/aarch64/jit_generator.hpp"
 
 // #define TR_DEBUG
@@ -67,23 +69,6 @@ static bool prb_has_small_strides(const prb_t &prb) {
     return true;
 }
 
-static bool prb_tail_friendly(const prb_t &prb) {
-    /* find optimal ndims to makes it easier to
-     * identify the blk_chunk in the loop*/
-    int ndims = prb.full_ndims - prb.ndims;
-
-    int n = prb.nodes[0].is;
-    for (int d = 1; d < prb.ndims; ++d) {
-        if (d != prb.blk_chunk_idx) n *= prb.nodes[d].n;
-    }
-    if (prb.ip_tail > 0
-            && ((ndims == 0 && n != 1)
-                    || (ndims > 0 && prb.ndims > prb.blk_chunk_idx)))
-        return false;
-
-    return true;
-}
-
 /** Minimal reasonable/desirable kernel size.
  * The constant might be used to determine how a problem should be split
  * between kernel and threading driver. */
@@ -96,6 +81,9 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
     void operator()(const call_param_t *c) const override {
         jit_generator::operator()(c);
     }
+    void operator()(const tail_call_param_t *c) const override {
+        jit_generator::operator()(c);
+    }
 
     status_t create_kernel() override { return jit_generator::create_kernel(); }
 
@@ -105,30 +93,53 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
     };
 
     struct simple_impl_desc_t {
-        int ndims_full_unroll;
-        int len_last_dim_unroll;
-        int len_unroll;
+        int ndims_full_unroll = 0;
+        int len_last_dim_unroll = 0;
+        int tail_len_unroll = 0;
+        int len_unroll = 0;
     };
 
+#define PARAM(x) \
+    abi_param1, \
+            prb_.is_tail_present ? offsetof(tail_call_param_t, base_params) \
+                    + offsetof(call_param_t, x) \
+                                 : offsetof(call_param_t, x)
+#define TAIL_PARAM(x) abi_param1, offsetof(tail_call_param_t, x)
+
     static bool simple_impl_desc_init(
             const prb_t &prb, simple_impl_desc_t *desc) {
         const int ndims = prb.ndims;
 
         int ndims_full_unroll = 0;
         int len_last_dim_unroll = 1;
+        int tail_len_unroll = 0;
         int len_unroll = 1;
 
-        for (int d = 0; d < ndims; ++d) {
-            auto &node = prb.nodes[d];
-            if (len_unroll * node.n <= len_unroll_max) {
-                ndims_full_unroll++;
-                len_unroll *= node.n;
-            } else {
-                len_last_dim_unroll = len_unroll_max / len_unroll;
-                while (node.n % len_last_dim_unroll)
-                    --len_last_dim_unroll;
-                len_unroll *= len_last_dim_unroll;
-                break;
+        // Work out how many elements the kernel can unroll.
+        // With a tail, only node 0 is unrolled (possible improvement);
+        // without a tail, several leading nodes may be unrolled with no loops.
+        // ndims_full_unroll   - number of fully unrolled nodes
+        // len_last_dim_unroll - unrolled part of the first node that does not
+        //                       fit entirely (always divides that node's size)
+        if (prb.is_tail_present) {
+            ndims_full_unroll = 1;
+            len_unroll = prb.nodes[0].n;
+            tail_len_unroll = prb.nodes[0].is_zero_pad_needed
+                    ? 0
+                    : static_cast<int>(prb.nodes[0].tail_size);
+        } else {
+            for (int d = 0; d < ndims; ++d) {
+                const auto &node = prb.nodes[d];
+                if (len_unroll * node.n <= len_unroll_max) {
+                    ndims_full_unroll++;
+                    len_unroll *= node.n;
+                } else {
+                    len_last_dim_unroll = len_unroll_max / len_unroll;
+                    while (node.n % len_last_dim_unroll)
+                        --len_last_dim_unroll;
+                    len_unroll *= len_last_dim_unroll;
+                    break;
+                }
             }
         }
 
@@ -137,6 +148,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
         if (desc) {
             desc->ndims_full_unroll = ndims_full_unroll;
             desc->len_last_dim_unroll = len_last_dim_unroll;
+            desc->tail_len_unroll = tail_len_unroll;
             desc->len_unroll = len_unroll;
         }
 
@@ -151,62 +163,69 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
                 && utils::one_of(p.otype, f32, s32, data_type::s8, u8)
                 && utils::everyone_is(0, p.ioff, p.ooff) /* do we need this? */
                 && utils::one_of(p.beta, 0.f, 1.f) /* anything else? */
-                && simple_impl_desc_init(p, nullptr) && prb_has_small_strides(p)
-                && prb_tail_friendly(p);
-        if (!ok) return false;
+                && simple_impl_desc_init(p, nullptr)
+                && prb_has_small_strides(p);
 
-        return true;
+        return ok;
     }
 
-    int n(int d) {
-        assert(d < prb_.ndims);
-        return (int)prb_.nodes[d].n;
-    }
-    int is(int d) {
-        assert(d < prb_.ndims);
-        return (int)prb_.nodes[d].is;
-    }
-    int os(int d) {
-        assert(d < prb_.ndims);
-        return (int)prb_.nodes[d].os;
+    XReg o_addr(int o_off, bool with_type_multiplier = true) {
+        if (o_off) {
+            add_imm(X_DEFAULT_ADDR, x_ptr_out_off,
+                    o_off * (with_type_multiplier ? otype_sz_ : 1), X_TMP_0);
+            return X_DEFAULT_ADDR;
+        }
+
+        return x_ptr_out_off;
     }
-    int ss(int d) {
-        assert(d < prb_.ndims);
-        return (int)prb_.nodes[d].ss;
+
+    XReg c_addr(int c_off) {
+        if (c_off) {
+            add_imm(X_DEFAULT_ADDR, x_ptr_comp_off, c_off, X_TMP_0);
+            return X_DEFAULT_ADDR;
+        }
+
+        return x_ptr_comp_off;
     }
 
-    int blk_cnt() {
-        assert(prb_.blk_chunk_idx < prb_.full_ndims);
-        return (int)prb_.nodes[prb_.blk_chunk_idx].n - 1;
+    XReg data_chunk_addr(int node_id) {
+        add_imm(X_DEFAULT_ADDR, abi_param1,
+                offsetof(tail_call_param_t, curr_data_chunks)
+                        + sizeof(int64_t) * (node_id),
+                X_TMP_0);
+        return X_DEFAULT_ADDR;
     }
-    int op_padding() { return prb_.op_tail ? prb_.iblock - prb_.op_tail : 0; }
-    int ip_padding() { return prb_.ip_tail ? prb_.oblock - prb_.ip_tail : 0; }
 
     void step(int off, int prev_i_off, int prev_o_off, int prev_s_off,
-            int &i_off, int &o_off, int &s_off, int step_size = 1) {
+            int prev_c_off, int &i_off, int &o_off, int &s_off, int &c_off,
+            int step_size = 1) {
         i_off = prev_i_off;
         o_off = prev_o_off;
         s_off = prev_s_off;
+        c_off = prev_c_off;
 
         if (off == 0) return;
 
         int start_dim = 0, dims_prod = 1;
         for (; start_dim < prb_.ndims && dims_prod != step_size; ++start_dim)
-            dims_prod *= n(start_dim);
+            dims_prod *= prb_.n(start_dim);
         assert(start_dim < prb_.ndims);
         off /= step_size;
 
-        for (int d = start_dim; d < prb_.ndims; ++d) {
-            i_off += is(d);
-            o_off += os(d);
-            s_off += ss(d);
+        for (int dim_id = start_dim; dim_id < prb_.ndims; ++dim_id) {
+            i_off += prb_.is(dim_id);
+            o_off += prb_.os(dim_id);
+            s_off += prb_.ss(dim_id);
+            c_off += prb_.cs(dim_id);
+
+            if (off % prb_.n(dim_id)) break;
 
-            if (off % n(d)) break;
+            i_off += -prb_.n(dim_id) * prb_.is(dim_id);
+            o_off += -prb_.n(dim_id) * prb_.os(dim_id);
+            s_off += -prb_.n(dim_id) * prb_.ss(dim_id);
+            c_off += -prb_.n(dim_id) * prb_.cs(dim_id);
 
-            i_off += -n(d) * is(d);
-            o_off += -n(d) * os(d);
-            s_off += -n(d) * ss(d);
-            off /= n(d);
+            off /= prb_.n(dim_id);
 
             if (off == 0) break; /* FIXME: is it really required? */
         }
@@ -215,8 +234,8 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
     void step(int off, int prev_i_off, int prev_o_off, int &i_off, int &o_off,
             int step_size = 1) {
         int dummy = 0;
-        step(off, prev_i_off, prev_o_off, dummy, i_off, o_off, dummy,
-                step_size);
+        step(off, prev_i_off, prev_o_off, dummy, dummy, i_off, o_off, dummy,
+                dummy, step_size);
     }
 
     void tr8x8_sve256(int i_off, int o_off) {
@@ -278,40 +297,36 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
                         && interim_f32);
         const uint64_t sveLen = get_sve_length();
 
-        add_imm(X_TMP_0, XReg(x_ptr_in_off), i_off * itype_sz, X_DEFAULT_ADDR);
-        add_imm(X_TMP_1, X_TMP_0, is(0) * itype_sz, X_DEFAULT_ADDR);
-        add_imm(X_TMP_2, X_TMP_1, is(0) * itype_sz, X_DEFAULT_ADDR);
-        add_imm(X_TMP_3, X_TMP_2, is(0) * itype_sz, X_DEFAULT_ADDR);
-
-        if (unroll * itype_sz == 32)
-            for (uint32_t i = 0; i < 4; i++)
-                ld1w(ZRegS {i}, p_lsb_256 / T_z, ptr(x_tmp_vec[i]));
-        else if (unroll * itype_sz == 16)
-            for (uint32_t i = 0; i < 4; i++)
-                ldr(QReg {i}, ptr(x_tmp_vec[i]));
-        else if (unroll * itype_sz == 8)
-            for (uint32_t i = 0; i < 4; i++)
-                ldr(DReg {i}, ptr(x_tmp_vec[i]));
-
-        add_imm(X_TMP_0, X_TMP_3, is(0) * itype_sz, X_DEFAULT_ADDR);
-        add_imm(X_TMP_1, X_TMP_0, is(0) * itype_sz, X_DEFAULT_ADDR);
-        add_imm(X_TMP_2, X_TMP_1, is(0) * itype_sz, X_DEFAULT_ADDR);
-        add_imm(X_TMP_3, X_TMP_2, is(0) * itype_sz, X_DEFAULT_ADDR);
-
-        if (unroll * itype_sz == 32)
-            for (uint32_t i = 0; i < 4; i++)
-                ld1w(ZRegS {4 + i}, p_lsb_256 / T_z, ptr(x_tmp_vec[i]));
-        else if (unroll * itype_sz == 16)
-            for (uint32_t i = 0; i < 4; i++)
-                ldr(QReg {4 + i}, ptr(x_tmp_vec[i]));
-        else if (unroll * itype_sz == 8)
-            for (uint32_t i = 0; i < 4; i++)
-                ldr(DReg {4 + i}, ptr(x_tmp_vec[i]));
+        PReg p_size(DUMMY_IDX);
+        switch (unroll * itype_sz_) {
+            case 32: p_size = p_lsb_256; break;
+            case 16: p_size = p_lsb_128; break;
+            case 8: p_size = p_lsb_64; break;
+            default: assert(!"unreachable");
+        }
+
+        const int node_0_input_stride = prb_.is(0);
+        add_imm(X_TMP_0, XReg(x_ptr_in_off), itype_sz_ * i_off, X_DEFAULT_ADDR);
+        for (int i = 1; i < unroll / 2; i++) {
+            add_imm(x_tmp_vec[i], x_tmp_vec[i - 1],
+                    itype_sz_ * node_0_input_stride, X_DEFAULT_ADDR);
+        }
+
+        for (uint32_t i = 0; i < unroll / 2; i++)
+            ld1w(ZRegS {i}, p_size / T_z, ptr(x_tmp_vec[i]));
+
+        for (int i = 0; i < unroll / 2; i++) {
+            add_imm(x_tmp_vec[i], x_tmp_vec[(i + 3) % 4],
+                    itype_sz_ * node_0_input_stride, X_DEFAULT_ADDR);
+        }
+
+        for (uint32_t i = 0; i < unroll / 2; i++)
+            ld1w(ZRegS {4 + i}, p_size / T_z, ptr(x_tmp_vec[i]));
 
         if (interim_f32) cvt2ps(0, unroll, prb_.itype);
 
 #if 0
-        /* Deubg code */
+        /* Debug code */
         index(z0.s, 0, 1);
         mov(z0.s, P_NOT_256/T_m, 0);
         mov(z_tmp_vec[0].s, 16);
@@ -348,9 +363,9 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
         for (uint32_t i = 0; i < unroll / 2; i++) {
             ZRegB z {unroll / 2 + i};
             ZRegB z_tmp = z_tmp_vec[unroll / 2 + i].b;
-            /* Move bit 128-255 to 0-127. */
-            ext(z, z, 16);
             /* Move bit 0-127 to 128-255. */
+            ext(z, z, 16);
+            /* Move bit 128-255 to 0-127. */
             ext(z_tmp, z_tmp, sveLen - 16);
         }
 
@@ -363,65 +378,64 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
         }
 
         if (need_saturation) {
-            init_saturate_f32(ymm_zero, ymm_saturation_ubound, reg_tmp,
+            init_saturate_f32(ymm_zero_, ymm_saturation_ubound_, reg_tmp_,
                     interim_f32 ? f32 : prb_.itype, prb_.otype);
             for (int i = 0; i < unroll; i++)
-                saturate_f32(ZRegS(i), ymm_zero, ymm_saturation_ubound,
-                        prb_.otype, p_all);
+                saturate_f32(ZRegS(i), ymm_zero_, ymm_saturation_ubound_,
+                        prb_.otype, P_ALL_ONE);
         }
 
         if (prb_.otype != f32)
             cvt2odt(0, unroll, prb_.otype, interim_f32 ? f32 : prb_.itype);
 
-        add_imm(X_TMP_0, XReg(x_ptr_out_off), o_off * otype_sz, X_DEFAULT_ADDR);
-        add_imm(X_TMP_1, X_TMP_0, os(1) * otype_sz, X_DEFAULT_ADDR);
-        add_imm(X_TMP_2, X_TMP_1, os(1) * otype_sz, X_DEFAULT_ADDR);
-        add_imm(X_TMP_3, X_TMP_2, os(1) * otype_sz, X_DEFAULT_ADDR);
-
-        if (unroll * otype_sz == 32)
-            for (uint32_t i = 0; i < 4; i++)
-                st1w(ZRegS {i}, p_lsb_256 / T_z, ptr(x_tmp_vec[i]));
-        else if (unroll * otype_sz == 16)
-            for (uint32_t i = 0; i < 4; i++)
-                str(QReg {i}, ptr(x_tmp_vec[i]));
-        else if (unroll * otype_sz == 8)
-            for (uint32_t i = 0; i < 4; i++)
-                str(DReg {i}, ptr(x_tmp_vec[i]));
-
-        add_imm(X_TMP_0, X_TMP_3, os(1) * otype_sz, X_DEFAULT_ADDR);
-        add_imm(X_TMP_1, X_TMP_0, os(1) * otype_sz, X_DEFAULT_ADDR);
-        add_imm(X_TMP_2, X_TMP_1, os(1) * otype_sz, X_DEFAULT_ADDR);
-        add_imm(X_TMP_3, X_TMP_2, os(1) * otype_sz, X_DEFAULT_ADDR);
-
-        if (unroll * otype_sz == 32)
-            for (uint32_t i = 0; i < 4; i++)
-                st1w(ZRegS {4 + i}, p_lsb_256 / T_z, ptr(x_tmp_vec[i]));
-        else if (unroll * otype_sz == 16)
-            for (uint32_t i = 0; i < 4; i++)
-                str(QReg {4 + i}, ptr(x_tmp_vec[i]));
-        else if (unroll * otype_sz == 8)
-            for (uint32_t i = 0; i < 4; i++)
-                str(DReg {4 + i}, ptr(x_tmp_vec[i]));
+        const int node_1_output_stride = prb_.os(1);
+
+        switch (unroll * otype_sz_) {
+            case 32: p_size = p_lsb_256; break;
+            case 16: p_size = p_lsb_128; break;
+            case 8: p_size = p_lsb_64; break;
+            default: assert(!"unreachable");
+        }
+
+        add_imm(X_TMP_0, XReg(x_ptr_out_off), otype_sz_ * o_off,
+                X_DEFAULT_ADDR);
+        for (int i = 1; i < unroll / 2; i++) {
+            add_imm(x_tmp_vec[i], x_tmp_vec[i - 1],
+                    otype_sz_ * node_1_output_stride, X_DEFAULT_ADDR);
+        }
+
+        for (uint32_t i = 0; i < 4; i++)
+            st1w(ZRegS {i}, p_size / T_z, ptr(x_tmp_vec[i]));
+
+        for (int i = 0; i < unroll / 2; i++) {
+            add_imm(x_tmp_vec[i], x_tmp_vec[(i + 3) % 4],
+                    otype_sz_ * node_1_output_stride, X_DEFAULT_ADDR);
+        }
+
+        for (uint32_t i = 0; i < unroll / 2; i++)
+            st1w(ZRegS {4 + i}, p_size / T_z, ptr(x_tmp_vec[i]));
     }
 
     bool can_do_tr8x8() {
         using namespace data_type;
 
-        return get_sve_length() >= Xbyak_aarch64::util::SVE_256
-                && prb_.ndims >= 2
+        static constexpr int desirable_node_size = 8;
+        static constexpr int desirable_stride = 1;
+
+        return mayiuse(sve_256) && prb_.ndims >= 2
                 && ((utils::one_of(prb_.itype, u8, data_type::s8, s32, f32)
                         && utils::one_of(
                                 prb_.otype, u8, data_type::s8, s32, f32)))
-                && utils::everyone_is(8, n(0), n(1))
-                && utils::everyone_is(1, os(0), is(1))
-                && utils::everyone_is(0, prb_.ip_tail, prb_.op_tail)
+                && utils::everyone_is(desirable_node_size, prb_.n(0), prb_.n(1))
+                && utils::everyone_is(desirable_stride, prb_.os(0), prb_.is(1))
+                && !prb_.is_tail_present
                 && prb_.scale_type == scale_type_t::NONE && prb_.beta == 0.f;
     }
 
-    bool process_unroll_tr8x8(int len) {
+    bool process_unroll_tr8x8(const int ndims, const int len) {
         if (!can_do_tr8x8()) return false;
 
-        const int step_size = n(0) * n(1);
+        const int step_size = prb_.n(0) * prb_.n(1);
         int i_off = 0, o_off = 0;
         for (int off = 0; off < len; off += step_size) {
             step(off, i_off, o_off, i_off, o_off, step_size);
@@ -432,23 +446,56 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
     }
 
     template <cpu_isa_t isa>
-    bool process_direct_copy(int len) {
+    bool process_direct_copy(const int ndims, const int len) {
         using namespace data_type;
 
-        const int simd_w = cpu_isa_traits<isa>::vlen / itype_sz;
-        bool can_do = true && mayiuse(isa)
-                && utils::everyone_is(1, os(0), is(0))
-                && (false || prb_.itype == prb_.otype
+        static constexpr int desirable_stride = 1;
+        using TRegS =
+                typename utils::conditional<isa == asimd, VReg4S, ZRegS>::type;
+        const int simd_w = cpu_isa_traits<isa>::vlen / itype_sz_;
+
+        // TODO: support tail_processing for direct copy
+
+        const bool do_src_zp = prb_.req_src_zp;
+        const bool do_dst_zp = prb_.req_dst_zp;
+        const bool zp_applicable = IMPLICATION(
+                (do_src_zp || do_dst_zp), utils::one_of(prb_.itype, s32, f32));
+        const bool can_do = true && mayiuse(isa)
+                && compensation_needed_ == false
+                && utils::everyone_is(desirable_stride, prb_.os(0), prb_.is(0))
+                && (false || (prb_.itype == prb_.otype ? zp_applicable : false)
                         || (prb_.itype == s32 && prb_.otype == f32)
                         || (prb_.itype == f32 && prb_.otype == s32))
-                && len % simd_w == 0 && n(0) % len == 0
-                && prb_.ip_tail % simd_w == 0 && prb_.op_tail % simd_w == 0
+                && len % simd_w == 0 && prb_.n(0) % len == 0
+                && !prb_.is_tail_present
                 && prb_.scale_type == scale_type_t::NONE && prb_.beta == 0.f;
         if (!can_do) return false;
 
+        static constexpr int vmm_zp_last_idx = 15;
+        const auto vmm_src_zp
+                = TRegS(do_dst_zp ? vmm_zp_last_idx - 1 : vmm_zp_last_idx);
+        if (do_src_zp) {
+            uni_ld1rw(vmm_src_zp, PARAM(src_zp));
+            uni_scvtf(vmm_src_zp, vmm_src_zp);
+        }
+        const auto vmm_dst_zp = TRegS(vmm_zp_last_idx);
+        if (do_dst_zp) {
+            uni_ld1rw(vmm_dst_zp, PARAM(dst_zp));
+            uni_scvtf(vmm_dst_zp, vmm_dst_zp);
+        }
+
+        const auto apply_zp_ps = [&](const TRegS vmm) {
+            if (do_src_zp) fsub(vmm, vmm, vmm_src_zp);
+            if (do_dst_zp) fadd(vmm, vmm, vmm_dst_zp);
+        };
+
         for (int off = 0; off < len;) {
-            const int unroll
+            // TODO: we need extra reg for proper saturation if otype == s32
+            int unroll
                     = nstl::min(16 - (prb_.otype == s32), (len - off) / simd_w);
+            unroll = (do_src_zp || do_dst_zp)
+                    ? nstl::min(unroll, 16 - do_src_zp - do_dst_zp)
+                    : unroll;
 
             int ur = 0;
             int tmp_ur = 0;
@@ -458,14 +505,11 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
 
                 do {
                     add_imm(x_tmp_vec[count++], x_ptr_in_off,
-                            (off + ur * simd_w) * itype_sz, X_DEFAULT_ADDR);
+                            (off + ur * simd_w) * itype_sz_, X_DEFAULT_ADDR);
                     ur++;
                 } while (ur < unroll && count < x_tmp_vec_size);
 
                 for (int i = 0; i < count; i++) {
-                    /*                    if (vlen == 64)
-                        ldr(ZReg(tmp_ur + i), ptr(x_tmp_vec[i]));
-                        else */
                     if (vlen == 64 || vlen == 32)
                         ld1w(ZRegS(tmp_ur + i), p_lsb_256 / T_z,
                                 ptr(x_tmp_vec[i]));
@@ -478,33 +522,28 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
             }
 
             if (prb_.itype != prb_.otype) {
-                const int vlen = cpu_isa_traits<isa>::vlen;
                 for (int ur = 0; ur < unroll; ++ur) {
+                    TRegS r(ur);
                     if (prb_.itype == s32 && prb_.otype == f32) {
-                        if (vlen == 64 || vlen == 32) {
-                            ZRegS r(ur);
-                            /* MSB side 256 bits are ignored. */
-                            scvtf(r, p_all / T_m, r);
-                        } else if (vlen == 16) {
-                            VReg4S r(ur);
-                            scvtf(r, r);
-                        } else
-                            assert(!"unreachable");
+                        uni_scvtf(r, r);
                     } else if (prb_.itype == f32 && prb_.otype == s32) {
-                        /* Out of order can be expected. */
-                        if (vlen == 64 || vlen == 32) {
-                            ZRegS r(ur);
-                            frinti(r, p_all / T_m, r);
-                            fcvtzs(r, p_all / T_m, r);
-                        } else if (vlen == 16) {
-                            VReg4S r(ur);
-                            frinti(r, r);
-                            fcvtzs(r, r);
-                        } else
-                            assert(!"unreachable");
+                        uni_frinti(r, r);
+                        uni_fcvtzs(r, r);
                     } else
                         assert(!"unreachable");
                 }
+            } else if (do_src_zp || do_dst_zp) {
+                for (int ur = 0; ur < unroll; ++ur) {
+                    const auto vmm = TRegS(ur);
+                    if (prb_.otype == f32) {
+                        apply_zp_ps(vmm);
+                    } else if (prb_.otype == s32) {
+                        uni_scvtf(vmm, vmm);
+                        apply_zp_ps(vmm);
+                        uni_frinti(vmm, vmm);
+                        uni_fcvtzs(vmm, vmm);
+                    }
+                }
             }
 
             ur = 0;
@@ -515,7 +554,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
 
                 do {
                     add_imm(x_tmp_vec[count++], x_ptr_out_off,
-                            (off + ur * simd_w) * otype_sz, X_DEFAULT_ADDR);
+                            (off + ur * simd_w) * otype_sz_, X_DEFAULT_ADDR);
                     ur++;
                 } while (ur < unroll && count < x_tmp_vec_size);
 
@@ -538,8 +577,8 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
     }
 
     void process_unroll_generic_step(int reg_unroll, const int *i_off,
-            const int *o_off, const int *s_off, const int *ip_padding,
-            const bool h_padded) {
+            const int *o_off, const int *s_off, const int *c_off,
+            const int *zero_padding, const bool tail_processing) {
         using namespace data_type;
 
         auto cvt2ps
@@ -588,76 +627,84 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
             }
         };
 
+        auto load_bytes_addr = [=](const int ur, const int r) {
+            add_imm(x_tmp_vec[r], x_ptr_in_off, i_off[ur + r] * itype_sz_,
+                    X_DEFAULT_ADDR);
+        };
+        auto load_bytes = [=](const int ur, int size, int r) {
+            switch (size) {
+                case 4: ld1(VReg4S(ur)[r], ptr(x_tmp_vec[r])); break;
+                case 2: ld1(VReg8H(ur)[r], ptr(x_tmp_vec[r])); break;
+                case 1: ld1(VReg16B(ur)[r], ptr(x_tmp_vec[r])); break;
+                default: assert(!"unreachable");
+            }
+        };
+
+        auto store = [=](const XReg &addr, const VReg ymm, int size) {
+            const uint32_t xmm = ymm.getIdx();
+            switch (size) {
+                case 16: str(QReg(xmm), ptr(addr)); break;
+                case 8: str(DReg(xmm), ptr(addr)); break;
+                case 4: str(SReg(xmm), ptr(addr)); break;
+                case 2: str(HReg(xmm), ptr(addr)); break;
+                case 1: str(BReg(xmm), ptr(addr)); break;
+                default: assert(!"unreachable");
+            }
+        };
+
         /* check whether loading 4 values at once is possible */
-        bool can_load_xmm = reg_unroll % 4 == 0;
+        static constexpr int xmm_vlen = 4;
+        bool can_load_xmm = reg_unroll % xmm_vlen == 0;
         for (int ur = 1; ur < reg_unroll; ++ur)
-            if (i_off[ur] != i_off[ur - 1] + 1) can_load_xmm = false;
-        const int load_step = can_load_xmm ? 4 : 1;
+            if (i_off[ur] != i_off[ur - 1] + 1) {
+                can_load_xmm = false;
+                break;
+            }
+        const int load_step = can_load_xmm ? xmm_vlen : 1;
 
         /* check whether storing 4 values at once is possible */
-        bool can_store_xmm = reg_unroll % 4 == 0;
+        bool can_store_xmm = reg_unroll % xmm_vlen == 0;
         for (int ur = 1; ur < reg_unroll; ++ur)
-            if (o_off[ur] != o_off[ur - 1] + 1) can_store_xmm = false;
+            if (o_off[ur] != o_off[ur - 1] + 1) {
+                can_store_xmm = false;
+                break;
+            }
         const int ur_step = can_store_xmm ? 4 : 1;
         const int load_tail_step
                 = !can_load_xmm && can_store_xmm ? ur_step : load_step;
 
-        const bool interim_f32 = false
-                || utils::one_of(f32, prb_.itype, prb_.otype)
-                || prb_.scale_type != scale_type_t::NONE || prb_.beta != 0.f;
+        const bool interim_f32 = interim_f32_needed();
 
         const bool need_saturation
                 = (utils::one_of(prb_.otype, u8, data_type::s8, s32)
                         && interim_f32);
-        if (h_padded) {
+
+        std::vector<int> store_masks;
+        if (tail_processing) {
             for (int ur = 0; ur < reg_unroll; ur += load_tail_step) {
-                if (itype_sz == 4)
-                    movi(VReg4S(ur), 0);
-                else if (itype_sz == 2)
-                    movi(VReg8H(ur), 0);
-                else
-                    movi(VReg16B(ur), 0);
-                /* x_tmp_vec = X_TMP_0 - X_TMP_4
-                 Do not use X_TMP_? as the last arg. */
-                for (int r = 0; r < load_tail_step; ++r) {
-                    if (ip_padding[ur + r] == 0) {
-                        add_imm(x_tmp_vec[r], x_ptr_in_off,
-                                i_off[ur + r] * itype_sz, X_DEFAULT_ADDR);
-                    }
-                }
+                uni_clear(VReg(ur));
+                store_masks.push_back(0);
 
                 for (int r = 0; r < load_tail_step; ++r) {
-                    if (ip_padding[ur + r] == 0) {
-                        if (itype_sz == 4)
-                            ld1(VReg4S(ur)[r], ptr(x_tmp_vec[r]));
-                        else if (itype_sz == 2)
-                            ld1(VReg8H(ur)[r], ptr(x_tmp_vec[r]));
-                        else
-                            ld1(VReg16B(ur)[r], ptr(x_tmp_vec[r]));
+                    if (zero_padding[ur + r] == 0) {
+                        store_masks.back() += 1 << r;
+                        load_bytes_addr(ur, r);
                     }
                 }
+
+                for (int r = 0; r < load_tail_step; ++r)
+                    if (zero_padding[ur + r] == 0) load_bytes(ur, itype_sz_, r);
             }
         } else {
             if (!can_load_xmm && can_store_xmm) {
-                assert(ur_step == 4);
+                assert(ur_step == xmm_vlen);
                 /* load with stride */
                 for (int ur = 0; ur < reg_unroll; ur += ur_step) {
-
-                    /* x_tmp_vec = X_TMP_0 - X_TMP_4
-                 Do not use X_TMP_? as the last arg. */
                     for (int r = 0; r < ur_step; ++r) {
-                        add_imm(x_tmp_vec[r], x_ptr_in_off,
-                                i_off[ur + r] * itype_sz, X_DEFAULT_ADDR);
-                    }
-
-                    for (int r = 0; r < ur_step; ++r) {
-                        if (itype_sz == 4)
-                            ld1(VReg4S(ur)[r], ptr(x_tmp_vec[r]));
-                        else if (itype_sz == 2)
-                            ld1(VReg8H(ur)[r], ptr(x_tmp_vec[r]));
-                        else
-                            ld1(VReg16B(ur)[r], ptr(x_tmp_vec[r]));
+                        load_bytes_addr(ur, r);
                     }
+                    for (int r = 0; r < ur_step; ++r)
+                        load_bytes(ur, itype_sz_, r);
                 }
             } else {
                 int ur = 0;
@@ -667,13 +714,13 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
 
                     do {
                         add_imm(x_tmp_vec[count++], x_ptr_in_off,
-                                i_off[ur] * itype_sz, X_DEFAULT_ADDR);
+                                i_off[ur] * itype_sz_, X_DEFAULT_ADDR);
                         ur += load_step;
                     } while (ur < reg_unroll && count < x_tmp_vec_size);
 
                     for (int i = 0; i < count; i++) {
 
-                        switch (load_step * itype_sz) {
+                        switch (load_step * itype_sz_) {
                             case 16:
                                 ldr(QReg(tmp_ur), ptr(x_tmp_vec[i]));
                                 break;
@@ -688,6 +735,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
                 }
             }
         }
+
         /* xmm[:] <-- (f32)xmm[:] */
         if (interim_f32) {
             const int cvt_step = nstl::max(load_step, ur_step);
@@ -702,30 +750,32 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
             if (fast_return) {
                 if (prb_.scale_type == scale_type_t::COMMON)
                     for (int ur = 0; ur < reg_unroll; ur += load_step)
-                        fmul(VReg4S(ur), VReg4S(ur), xmm_scale);
+                        fmul(VReg4S(ur), VReg4S(ur), xmm_scale_);
                 if (prb_.otype != f32) {
-                    init_saturate_f32(xmm_zero, xmm_saturation_ubound, reg_tmp,
-                            interim_f32 ? f32 : prb_.itype, prb_.otype);
-                    for (int ur = 0; ur < reg_unroll; ur += load_step)
+                    init_saturate_f32(xmm_zero_, xmm_saturation_ubound_,
+                            reg_tmp_, interim_f32 ? f32 : prb_.itype,
+                            prb_.otype);
+                    for (int ur = 0; ur < reg_unroll; ur += load_step) {
                         if (need_saturation)
-                            saturate_f32(VReg4S(ur), xmm_zero,
-                                    xmm_saturation_ubound, prb_.otype, p_all);
+                            saturate_f32(VReg4S(ur), xmm_zero_,
+                                    xmm_saturation_ubound_, prb_.otype,
+                                    P_ALL_ONE);
+                    }
 
                     for (int ur = 0; ur < reg_unroll; ur += load_step)
                         cvt2odt(ur, 1, prb_.otype,
                                 interim_f32 ? f32 : prb_.itype);
                 }
-                /* load_step is 1 or 4. */
                 for (int ur = 0; ur < reg_unroll; ur += load_step) {
                     for (int r = 0; r < load_step; ++r) {
                         add_imm(x_tmp_vec[r], x_ptr_out_off,
-                                o_off[ur + r] * otype_sz, X_DEFAULT_ADDR);
+                                o_off[ur + r] * otype_sz_, X_DEFAULT_ADDR);
                     }
 
                     for (int r = 0; r < load_step; ++r) {
-                        if (otype_sz == 4)
+                        if (otype_sz_ == 4)
                             st1(VReg4S(ur)[r], ptr(x_tmp_vec[r]));
-                        else if (otype_sz == 2)
+                        else if (otype_sz_ == 2)
                             st1(VReg8H(ur)[r], ptr(x_tmp_vec[r]));
                         else
                             st1(VReg16B(ur)[r], ptr(x_tmp_vec[r]));
@@ -735,7 +785,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
             }
 
             /* scatter elements of xmm into 4 xmms */
-            if (itype_sz == 4 || interim_f32) {
+            if (itype_sz_ == 4 || interim_f32) {
                 for (int ur = 0; ur < reg_unroll; ur += load_step)
                     for (int r = 1; r < load_step; ++r) {
                         VReg4S v(ur);
@@ -747,7 +797,18 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
                 for (int ur = 0; ur < reg_unroll; ur += load_step)
                     for (int r = 1; r < load_step; ++r)
                         ext(VReg16B(ur + r), VReg16B(ur), VReg16B(ur),
-                                itype_sz * r);
+                                itype_sz_ * r);
+            }
+        }
+
+        /* src zero point application */
+        if (prb_.req_src_zp) {
+            for (int ur = 0; ur < reg_unroll; ur += ur_step) {
+                const auto xmm = VReg4S(ur);
+                if (interim_f32)
+                    fsub(xmm, xmm, xmm_src_zp_);
+                else
+                    sub(xmm, xmm, xmm_src_zp_);
             }
         }
 
@@ -756,7 +817,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
             /* xmm <-- scale * xmm[:] */
             if (prb_.scale_type == scale_type_t::COMMON) {
                 for (int ur = 0; ur < reg_unroll; ur += ur_step)
-                    fmul(VReg4S(ur), VReg4S(ur), xmm_scale);
+                    fmul(VReg4S(ur), VReg4S(ur), xmm_scale_);
             } else if (prb_.scale_type == scale_type_t::MANY) {
                 enum class scale_load_type_t { bcast, load, gather };
 
@@ -769,13 +830,12 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
                             scale_load_type = scale_load_type_t::load;
 
                     if (scale_load_type == scale_load_type_t::bcast
-                            && !h_padded) {
-                        VReg4S v(xmm_scale.getIdx());
+                            && !tail_processing) {
+                        VReg4S v(xmm_scale_.getIdx());
                         VReg4S v_dst(ur);
-                        add_imm(X_TMP_0, x_ptr_scale_off, s_off[ur] * stype_sz,
+                        add_imm(X_TMP_0, x_ptr_scale_off, s_off[ur] * stype_sz_,
                                 X_DEFAULT_ADDR);
-                        ldr(W_TMP_0, ptr(X_TMP_0));
-                        dup(v, W_TMP_0);
+                        ld1r(v, ptr(X_TMP_0));
                         fmul(v_dst, v_dst, v);
                         continue;
                     }
@@ -786,10 +846,10 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
                             scale_load_type = scale_load_type_t::gather;
 
                     if (scale_load_type == scale_load_type_t::load
-                            && !h_padded) {
-                        uint32_t idx = xmm_scale.getIdx();
+                            && !tail_processing) {
+                        uint32_t idx = xmm_scale_.getIdx();
                         VReg4S v_dst(ur);
-                        add_imm(X_TMP_0, x_ptr_scale_off, s_off[ur] * stype_sz,
+                        add_imm(X_TMP_0, x_ptr_scale_off, s_off[ur] * stype_sz_,
                                 X_DEFAULT_ADDR);
 
                         ldr(QReg {idx}, ptr(X_TMP_0));
@@ -799,22 +859,15 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
 
                     // load doesn't work as well
                     // so gather the scale factors one by one
-                    /*ur_step is 1 or 4. */
-                    for (int r = ur; r < ur + ur_step; ++r) {
-                        if (ip_padding[r] == 0 || !h_padded) {
-                            /* x_tmp_vec = X_TMP_0 - X_TMP_4
-                         Do not use X_TMP_? as the last arg. */
+                    for (int r = ur; r < ur + ur_step; ++r)
+                        if (zero_padding[r] == 0 || !tail_processing) {
                             add_imm(x_tmp_vec[r - ur], x_ptr_scale_off,
-                                    s_off[r] * stype_sz, X_DEFAULT_ADDR);
-                        }
-                    }
-                    for (int r = ur; r < ur + ur_step; ++r) {
-                        if (ip_padding[r] == 0 || !h_padded) {
-                            VReg4S v(xmm_scale.getIdx());
-                            ld1(v[r - ur], ptr(x_tmp_vec[r - ur]));
+                                    s_off[r] * stype_sz_, X_DEFAULT_ADDR);
                         }
-                    }
-                    fmul(VReg4S(ur), VReg4S(ur), xmm_scale);
+                    for (int r = ur; r < ur + ur_step; ++r)
+                        if (zero_padding[r] == 0 || !tail_processing)
+                            ld1(xmm_scale_[r - ur], ptr(x_tmp_vec[r - ur]));
+                    fmul(VReg4S(ur), VReg4S(ur), xmm_scale_);
                 }
             }
 
@@ -829,7 +882,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
 
                     do {
                         add_imm(x_tmp_vec[count++], x_ptr_out_off,
-                                o_off[ur] * otype_sz, X_DEFAULT_ADDR);
+                                o_off[ur] * otype_sz_, X_DEFAULT_ADDR);
                         ur += ur_step;
                     } while (ur < reg_unroll && count < x_tmp_vec_size);
 
@@ -873,7 +926,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
             if (prb_.scale_type == scale_type_t::COMMON) {
                 for (int ur = 0; ur < reg_unroll; ur += ur_step) {
                     VReg4S tmp(ur);
-                    fmul(tmp, tmp, VReg4S(xmm_scale.getIdx()));
+                    fmul(tmp, tmp, VReg4S(xmm_scale_.getIdx()));
                 }
             } else if (prb_.scale_type == scale_type_t::MANY) {
                 int ur = 0;
@@ -883,7 +936,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
 
                     do {
                         add_imm(x_tmp_vec[count++], x_ptr_scale_off,
-                                s_off[ur] * stype_sz, X_DEFAULT_ADDR);
+                                s_off[ur] * stype_sz_, X_DEFAULT_ADDR);
                         ur += ur_step;
                     } while (ur < reg_unroll && count < x_tmp_vec_size);
 
@@ -908,7 +961,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
 
                     do {
                         add_imm(x_tmp_vec[count++], x_ptr_out_off,
-                                o_off[ur] * otype_sz, X_DEFAULT_ADDR);
+                                o_off[ur] * otype_sz_, X_DEFAULT_ADDR);
                         ur += ur_step;
                     } while (ur < reg_unroll && count < (x_tmp_vec_size / 2));
 
@@ -951,94 +1004,272 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
             }
         }
 
-        if (need_saturation) {
-            init_saturate_f32(
-                    xmm_zero, xmm_saturation_ubound, reg_tmp, f32, prb_.otype);
+        /* dst zero point application */
+        if (prb_.req_dst_zp) {
             for (int ur = 0; ur < reg_unroll; ur += ur_step) {
-                saturate_f32(VReg4S(ur), xmm_zero, xmm_saturation_ubound,
-                        prb_.otype, p_all);
+                const auto xmm = VReg4S(ur);
+                if (interim_f32)
+                    fadd(xmm, xmm, xmm_dst_zp_);
+                else
+                    add(xmm, xmm, xmm_dst_zp_);
             }
         }
 
-        for (int ur = 0; ur < reg_unroll; ur += ur_step) {
-            if (prb_.otype != f32)
-                cvt2odt(ur, 1, prb_.otype, interim_f32 ? f32 : prb_.itype);
+        /* adjust scale application */
+        if (prb_.scale_adjust != 1.f) {
+            dup(xmm_tmp_, reg_scale_adjust_);
+            for (int ur = 0; ur < reg_unroll; ur += ur_step) {
+                fmul(VReg4S(ur), VReg4S(ur), xmm_tmp_);
+            }
+        }
+
+        if (need_saturation) {
+            init_saturate_f32(xmm_zero_, xmm_saturation_ubound_, reg_tmp_, f32,
+                    prb_.otype);
+            for (int ur = 0; ur < reg_unroll; ur += ur_step) {
+                saturate_f32(VReg4S(ur), xmm_zero_, xmm_saturation_ubound_,
+                        prb_.otype, P_ALL_ONE);
+            }
+
+            // reset back xmm_zero_ if needed.
+            if (compensation_needed_ && (prb_.req_src_zp || prb_.req_dst_zp))
+                uni_clear(VReg(xmm_zero_.getIdx()));
         }
 
-        int ur = 0;
-        int tmp_ur = 0;
-        while (ur < reg_unroll) {
-            int count = 0;
+        if (compensation_needed_) {
+            const uint32_t xmm_begin = 9;
+            const uint32_t xmm_end = 11;
+            uint32_t xmm_id = xmm_begin;
+            const auto get_temp_xmm = [&] {
+                const Xbyak_aarch64::VReg temp {xmm_id++};
+
+                if (xmm_id > xmm_end) { xmm_id = xmm_begin; }
+
+                return temp;
+            };
+            if (can_store_xmm) {
+                enum class comp_load_type_t { bcast, load, gather };
+
+                for (int ur = 0; ur < reg_unroll; ur += ur_step) {
+
+                    bool all_ip_padding_one = true;
+                    bool all_ip_padding_zero = true;
+                    for (int r = ur; r < ur + ur_step; r++) {
+                        if (zero_padding[r] != 1)
+                            all_ip_padding_one = false;
+                        else
+                            all_ip_padding_zero = false;
+                    }
+                    if (all_ip_padding_one) continue;
+
+                    comp_load_type_t comp_load_type = comp_load_type_t::bcast;
+
+                    for (int r = ur + 1; r < ur + ur_step; ++r)
+                        if (c_off[r] != c_off[r - 1] + 0) {
+                            comp_load_type = comp_load_type_t::load;
+                            break;
+                        }
 
-            do {
-                add_imm(x_tmp_vec[count++], x_ptr_out_off, o_off[ur] * otype_sz,
-                        X_DEFAULT_ADDR);
-                ur += ur_step;
-            } while (ur < reg_unroll && count < x_tmp_vec_size);
+                    if (comp_load_type == comp_load_type_t::bcast
+                            && all_ip_padding_zero) {
+                        const auto reduction_xmm = get_temp_xmm().s4;
+                        const auto xmm_reorder_result = VReg4S(ur);
+                        frinti(reduction_xmm, xmm_reorder_result);
+                        addv(SReg(reduction_xmm.getIdx()), reduction_xmm);
+                        const auto comp_addr = c_addr(c_off[ur]);
+                        const auto xmm_tmp_ = get_temp_xmm().s4;
+                        ldr(SReg(xmm_tmp_.getIdx()), ptr(comp_addr));
+                        add(xmm_tmp_, xmm_tmp_, reduction_xmm);
+                        str(SReg(xmm_tmp_.getIdx()), ptr(comp_addr));
+                        continue;
+                    }
+
+                    if (comp_load_type == comp_load_type_t::load)
+                        for (int r = ur + 1; r < ur + ur_step; ++r)
+                            if (c_off[r] != c_off[r - 1] + 1) {
+                                comp_load_type = comp_load_type_t::gather;
+                                break;
+                            }
+
+                    if (comp_load_type == comp_load_type_t::load
+                            && all_ip_padding_zero) {
+                        const auto xmm_reorder_result_dq = get_temp_xmm().s4;
+                        const auto xmm_reorder_result = VReg4S(ur);
+                        const auto comp_addr = c_addr(c_off[ur]);
+                        frinti(xmm_reorder_result_dq, xmm_reorder_result);
+                        const auto xmm_tmp_ = get_temp_xmm().s4;
+                        ldr(SReg(xmm_tmp_.getIdx()), ptr(comp_addr));
+                        add(xmm_reorder_result_dq, xmm_reorder_result_dq,
+                                xmm_tmp_);
+                        str(SReg(xmm_tmp_.getIdx()), ptr(comp_addr));
+                        continue;
+                    }
 
-            for (int i = 0; i < count; i++) {
+                    const auto xmm_reorder_result_dq = get_temp_xmm().s4;
+                    const auto xmm_reorder_result = VReg4S(ur);
+                    frinti(xmm_reorder_result_dq, xmm_reorder_result);
 
-                switch (ur_step * otype_sz) {
-                    case 16: str(QReg(tmp_ur), ptr(x_tmp_vec[i])); break;
-                    case 8: str(DReg(tmp_ur), ptr(x_tmp_vec[i])); break;
-                    case 4: str(SReg(tmp_ur), ptr(x_tmp_vec[i])); break;
-                    case 2: str(HReg(tmp_ur), ptr(x_tmp_vec[i])); break;
-                    case 1: str(BReg(tmp_ur), ptr(x_tmp_vec[i])); break;
-                    default: assert(!"unreachable");
+                    for (int r = ur; r < ur + ur_step; ++r) {
+                        if (zero_padding[r] == 0 || !tail_processing) {
+                            mov(W_TMP_0, xmm_reorder_result_dq[r]);
+                            const auto comp_addr = c_addr(c_off[ur]);
+                            str(W_TMP_0, ptr(comp_addr));
+                        }
+                    }
+                }
+            } else {
+                for (int ur = 0; ur < reg_unroll; ur += ur_step) {
+                    if (zero_padding[ur] == 0 || !tail_processing) {
+                        const auto xmm_reorder_result_dq = get_temp_xmm().s4;
+                        const auto xmm_reorder_result = VReg4S(ur);
+                        const auto comp_addr = c_addr(c_off[ur]);
+                        frinti(xmm_reorder_result_dq, xmm_reorder_result);
+                        const auto xmm_tmp_ = get_temp_xmm().s4;
+                        ldr(SReg(xmm_tmp_.getIdx()), ptr(comp_addr));
+                        add(xmm_reorder_result_dq, xmm_reorder_result_dq,
+                                xmm_tmp_);
+                        str(SReg(xmm_tmp_.getIdx()), ptr(comp_addr));
+                    }
                 }
-                tmp_ur += ur_step;
             }
         }
+
+        for (int ur = 0; ur < reg_unroll; ur += ur_step) {
+            if (prb_.req_src_zp || prb_.req_dst_zp) {
+                const bool use_store_masks = !store_masks.empty();
+                if (use_store_masks) {
+                    const auto mask = (~store_masks[ur / ur_step]) & 0xF;
+                    switch (mask) {
+                        case 0x0:
+                            /* Do nothing */
+                            break;
+                        case 0x1: ins(VReg4S(ur)[0], xmm_zero_[0]); break;
+                        case 0x2: ins(VReg4S(ur)[1], xmm_zero_[1]); break;
+                        case 0x3:
+                            ins(VReg2D(ur)[0], VReg2D(xmm_zero_.getIdx())[0]);
+                            break;
+                        case 0x4: ins(VReg4S(ur)[2], xmm_zero_[2]); break;
+                        case 0x5:
+                            ins(VReg4S(ur)[0], xmm_zero_[0]);
+                            ins(VReg4S(ur)[2], xmm_zero_[2]);
+                            break;
+                        case 0x6:
+                            ins(VReg4S(ur)[1], xmm_zero_[1]);
+                            ins(VReg4S(ur)[2], xmm_zero_[2]);
+                            break;
+                        case 0x7:
+                            ins(VReg2D(ur)[0], VReg2D(xmm_zero_.getIdx())[0]);
+                            ins(VReg4S(ur)[2], xmm_zero_[2]);
+                            break;
+                        case 0x8: ins(VReg4S(ur)[3], xmm_zero_[3]); break;
+                        case 0x9:
+                            ins(VReg4S(ur)[0], xmm_zero_[0]);
+                            ins(VReg4S(ur)[3], xmm_zero_[3]);
+                            break;
+                        case 0xa:
+                            ins(VReg4S(ur)[1], xmm_zero_[1]);
+                            ins(VReg4S(ur)[3], xmm_zero_[3]);
+                            break;
+                        case 0xb:
+                            ins(VReg2D(ur)[0], VReg2D(xmm_zero_.getIdx())[0]);
+                            ins(VReg4S(ur)[3], xmm_zero_[3]);
+                            break;
+                        case 0xc:
+                            ins(VReg2D(ur)[1], VReg2D(xmm_zero_.getIdx())[1]);
+                            break;
+                        case 0xd:
+                            ins(VReg4S(ur)[0], xmm_zero_[0]);
+                            ins(VReg2D(ur)[1], VReg2D(xmm_zero_.getIdx())[1]);
+                            break;
+                        case 0xe:
+                            ins(VReg4S(ur)[1], xmm_zero_[1]);
+                            ins(VReg2D(ur)[1], VReg2D(xmm_zero_.getIdx())[1]);
+                            break;
+                        case 0xf: movi(VReg16B(ur), 0); break;
+                        default: assert(!"unreachable");
+                    }
+                }
+            }
+            if (prb_.otype != f32)
+                cvt2odt(ur, 1, prb_.otype, interim_f32 ? f32 : prb_.itype);
+
+            store(o_addr(o_off[ur]), VReg(ur), ur_step * otype_sz_);
+        }
     }
 
-    void comp_padding_flag(int ndims, int off, int len, int &i_tail) {
-        const int ip_without_padding
-                = ndims == 0 ? len - ip_padding() : prb_.ip_tail;
-        if ((ndims == 0 && off >= ip_without_padding)
-                || (ndims > 0 && (off % prb_.oblock) >= ip_without_padding))
-            i_tail = 1;
+    bool interim_f32_needed() {
+        using namespace data_type;
+        // Value of the interim_f32 flag: any one condition below enables it.
+        return utils::one_of(f32, prb_.itype, prb_.otype)
+                || prb_.scale_type != scale_type_t::NONE || prb_.beta != 0.f
+                || ((prb_.req_src_zp || prb_.req_dst_zp)
+                                ? !(prb_.itype == s32 && prb_.otype == s32)
+                                : false) // s32->s32 needs no f32 for zp
+                || (prb_.itype != f32 && compensation_needed_)
+                || prb_.scale_adjust != 1.f;
     }
 
-    void process_unroll_generic(const int ndims, int len, const bool h_padded) {
+    void process_unroll_generic(
+            const int ndims, int len, const bool tail_processing) {
+        assert(IMPLICATION(prb_.nodes[0].tail_size > 0,
+                len == static_cast<int>(prb_.nodes[0].n)
+                        || len == static_cast<int>(prb_.nodes[0].tail_size)));
+        // Work proceeds in chunks of blk elements; offsets cover two chunks.
         const int blk = 8;
 
         int i_off[2 * blk] = {0};
         int o_off[2 * blk] = {0};
         int s_off[2 * blk] = {0};
+        int c_off[2 * blk] = {0};
 
         int curr = 0; // will switch between 0 and 1
 
+        const bool interim_f32 = interim_f32_needed();
+        // Zero points are loaded with ld1r and converted by scvtf if interim_f32.
+        if (prb_.req_src_zp) {
+            add_imm(X_DEFAULT_ADDR, PARAM(src_zp), X_TMP_0);
+            ld1r(xmm_src_zp_, ptr(X_DEFAULT_ADDR));
+            if (interim_f32) scvtf(xmm_src_zp_, xmm_src_zp_);
+        }
+        if (prb_.req_dst_zp) {
+            add_imm(X_DEFAULT_ADDR, PARAM(dst_zp), X_TMP_0);
+            ld1r(xmm_dst_zp_, ptr(X_DEFAULT_ADDR));
+            if (interim_f32) scvtf(xmm_dst_zp_, xmm_dst_zp_);
+        }
+
         for (int off = 0; off < len; off += blk) {
             const int reg_unroll = nstl::min(off + blk, len) - off;
-            int ip_padding[blk] = {0};
+            int zero_padding[blk] = {0};
+            const auto curr_blk = curr * blk;
 
             /* compute offsets and tail*/
             for (int ur = off != 0 ? 0 : 1; ur < reg_unroll; ++ur) {
-                const int ur_c = curr * blk + ur;
+                const int ur_c = curr_blk + ur;
                 const int ur_p = (ur_c - 1 + 2 * blk) % (2 * blk); // prev ur
+                const bool is_tail
+                        = off + ur >= static_cast<int>(prb_.nodes[0].tail_size);
                 step(off + ur, i_off[ur_p], o_off[ur_p], s_off[ur_p],
-                        i_off[ur_c], o_off[ur_c], s_off[ur_c]);
-                if (h_padded)
-                    comp_padding_flag(ndims, off + ur, len, ip_padding[ur]);
+                        c_off[ur_p], i_off[ur_c], o_off[ur_c], s_off[ur_c],
+                        c_off[ur_c]);
+                if (tail_processing && is_tail) zero_padding[ur] = 1;
             }
-            process_unroll_generic_step(reg_unroll, i_off + curr * blk,
-                    o_off + curr * blk, s_off + curr * blk, ip_padding,
-                    h_padded);
+            // Flagged entries of zero_padding lie at/after nodes[0].tail_size.
+            process_unroll_generic_step(reg_unroll, i_off + curr_blk,
+                    o_off + curr_blk, s_off + curr_blk, c_off + curr_blk,
+                    zero_padding, tail_processing);
 
             curr = 1 - curr;
         }
     }
 
     void compute_ker(
-            const int ndims, const int len_unroll, const bool h_padded) {
+            const int ndims, const int len_unroll, const bool tail_processing) {
         bool optimized = false;
-        optimized = optimized
-                || (process_direct_copy<sve_256>(len_unroll) && !h_padded);
-        optimized = optimized
-                || (process_direct_copy<asimd>(len_unroll) && !h_padded);
-        optimized
-                = optimized || (process_unroll_tr8x8(len_unroll) && !h_padded);
-        if (!optimized) process_unroll_generic(ndims, len_unroll, h_padded);
+        optimized = optimized || process_direct_copy<sve_256>(ndims, len_unroll)
+                || process_direct_copy<asimd>(ndims, len_unroll)
+                || process_unroll_tr8x8(ndims, len_unroll); // first hit wins
+        if (!optimized) // none of the above handled it: use generic path
+            process_unroll_generic(ndims, len_unroll, tail_processing);
     }
 
     void loop_begin(Label &l, XReg reg_cnt, int len) {
@@ -1046,97 +1277,287 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
         L(l);
     }
 
+    void check_if_this_is_last_chunk(const XReg reg_curr_chunk, int node_id) {
+        // Chunks are numbered backwards, i.e.:
+        // [0] -> [node_size]
+        // [1] -> [node_size - 1]
+        // ...
+        // [node_size - 1] -> [1]
+        // so the last chunk carries the value 1 (node_id is not used here).
+        // It is done like this because it is easier to decrement a counter
+        // and check whether it is zero than to increment it and check
+        // whether it equals node_size. Only a cmp is emitted by this helper.
+        static constexpr int64_t last_chunk = 1;
+        cmp(reg_curr_chunk, last_chunk);
+    }
+
+    void zero_dst_memory(const int bytes_to_zeroing) {
+        static constexpr int num_of_bytes_in_xmm = 128 / 8;
+        // Split the request into whole 16-byte stores plus a leftover of bytes.
+        const int xmms_to_zeroing
+                = std::div(bytes_to_zeroing, num_of_bytes_in_xmm).quot;
+        const int tail_to_zeroing
+                = std::div(bytes_to_zeroing, num_of_bytes_in_xmm).rem;
+
+        movi(xmm_tmp_, 0);
+
+        if (xmms_to_zeroing > 0) {
+            Label loop;
+            // Each iteration stores the zeroed register and moves dst 16 bytes.
+            mov(reg_tmp_, xmms_to_zeroing);
+            L(loop);
+            str(QReg(xmm_tmp_.getIdx()), ptr(o_addr(0)));
+            add_imm(reg_off_out_, reg_off_out_, num_of_bytes_in_xmm, X_TMP_0);
+            add_imm(x_ptr_out_off, x_ptr_out_off, num_of_bytes_in_xmm, X_TMP_0);
+            subs(reg_tmp_, reg_tmp_, 1);
+            mov(X_TMP_0, 32); // NOTE(review): value looks unused - confirm
+            b(NE, loop);
+        }
+
+        if (tail_to_zeroing) mov_imm(W_TMP_0, 0);
+        // Leftover bytes are zeroed one strb at a time.
+        for (int i = 0; i < tail_to_zeroing; i++)
+            strb(W_TMP_0, ptr(o_addr(i, false)));
+
+        // Restore the dst offsets advanced by the 16-byte store loop above
+        if (xmms_to_zeroing > 0) {
+            sub_imm(reg_off_out_, reg_off_out_,
+                    num_of_bytes_in_xmm * xmms_to_zeroing, X_TMP_0);
+            sub_imm(x_ptr_out_off, x_ptr_out_off,
+                    num_of_bytes_in_xmm * xmms_to_zeroing, X_TMP_0);
+        }
+    }
+
+    void finalize_tail_loop(int i_step, int o_step, int s_step, int c_step,
+            const int curr_node_id) {
+        static constexpr int empty_chunk_info = -1;
+        // Overwrite this node's data chunk slot with the empty marker.
+        mov(reg_tmp_, empty_chunk_info);
+        str(reg_tmp_, ptr(data_chunk_addr(curr_node_id)));
+        // Number of positions of this node that lie beyond its tail_size.
+        const int padded_area = prb_.nodes[curr_node_id].n
+                - prb_.nodes[curr_node_id].tail_size;
+
+        if (prb_.nodes[curr_node_id].is_zero_pad_needed) {
+            int num_of_zero_padded_values = padded_area;
+            for (int i = curr_node_id - 1; i >= 0; i--) {
+                num_of_zero_padded_values *= prb_.nodes[i].n;
+            }
+
+            const int bytes_to_zeroing = num_of_zero_padded_values * otype_sz_;
+            zero_dst_memory(bytes_to_zeroing);
+        }
+
+        // This function is called by loop_end. The final section
+        // of loop_end restores the offset values, and that
+        // restore is based on the len value, which is equal to
+        // prb.nodes[x].n. When this function runs, the offsets
+        // were shifted only prb.nodes[x].tail_size times.
+        // Therefore, this function has to shift the offsets by
+        // the padded area so that the restore in loop_end
+        // brings them back to their initial values.
+        add_imm(reg_off_in_, reg_off_in_, padded_area * i_step * itype_sz_,
+                X_TMP_0);
+        add_imm(reg_off_out_, reg_off_out_, padded_area * o_step * otype_sz_,
+                X_TMP_0);
+        add_imm(x_ptr_in_off, x_ptr_in_off, padded_area * i_step * itype_sz_,
+                X_TMP_0);
+        add_imm(x_ptr_out_off, x_ptr_out_off, padded_area * o_step * otype_sz_,
+                X_TMP_0);
+        if (prb_.scale_type == scale_type_t::MANY) {
+            add_imm(reg_off_scale_, reg_off_scale_,
+                    padded_area * s_step * stype_sz_, X_TMP_0);
+            add_imm(x_ptr_scale_off, x_ptr_scale_off,
+                    padded_area * s_step * stype_sz_, X_TMP_0);
+        }
+        if (compensation_needed_) {
+            add_imm(reg_off_comp_, reg_off_comp_,
+                    padded_area * c_step * sizeof(int32_t), X_TMP_0);
+            add_imm(x_ptr_comp_off, x_ptr_comp_off,
+                    padded_area * c_step * sizeof(int32_t), X_TMP_0);
+        }
+    }
+
     void loop_end(Label &l, XReg reg_cnt, int len, int i_step, int o_step,
-            int s_step) {
-        add_imm(reg_off_in, reg_off_in, i_step * itype_sz, X_TMP_0);
-        add_imm(reg_off_out, reg_off_out, o_step * otype_sz, X_TMP_0);
-        add_imm(x_ptr_in_off, x_ptr_in_off, i_step * itype_sz, X_TMP_0);
-        add_imm(x_ptr_out_off, x_ptr_out_off, o_step * otype_sz, X_TMP_0);
+            int s_step, int c_step, const int curr_node_id) {
+        add_imm(reg_off_in_, reg_off_in_, i_step * itype_sz_, X_TMP_0);
+        add_imm(reg_off_out_, reg_off_out_, o_step * otype_sz_, X_TMP_0);
+        add_imm(x_ptr_in_off, x_ptr_in_off, i_step * itype_sz_, X_TMP_0);
+        add_imm(x_ptr_out_off, x_ptr_out_off, o_step * otype_sz_, X_TMP_0);
 
         if (prb_.scale_type == scale_type_t::MANY) {
-            add_imm(reg_off_scale, reg_off_scale, s_step * stype_sz, X_TMP_0);
-            add_imm(x_ptr_scale_off, x_ptr_scale_off, s_step * stype_sz,
+            add_imm(reg_off_scale_, reg_off_scale_, s_step * stype_sz_,
+                    X_TMP_0);
+            add_imm(x_ptr_scale_off, x_ptr_scale_off, s_step * stype_sz_,
                     X_TMP_0);
         }
+
+        if (compensation_needed_) {
+            add_imm(reg_off_comp_, reg_off_comp_, c_step * sizeof(int32_t),
+                    X_TMP_0);
+            add_imm(x_ptr_comp_off, x_ptr_comp_off, c_step * sizeof(int32_t),
+                    X_TMP_0);
+        }
+        // Count down and branch back to l while the counter is non-zero.
         subs(reg_cnt, reg_cnt, 1);
         b(NE, l);
 
-        sub_imm(reg_off_in, reg_off_in, len * i_step * itype_sz, X_TMP_0);
-        sub_imm(reg_off_out, reg_off_out, len * o_step * otype_sz, X_TMP_0);
-        sub_imm(x_ptr_in_off, x_ptr_in_off, len * i_step * itype_sz, X_TMP_0);
-        sub_imm(x_ptr_out_off, x_ptr_out_off, len * o_step * otype_sz, X_TMP_0);
+        if (prb_.tail(curr_node_id) != 0) {
+            Label if_end;
+
+            // The stack holds the information whether this node was
+            // processed with a tail or not; it is popped from the stack here.
+            ldr(reg_tmp_, post_ptr(X_SP, reg_tmp_.getBit() / 8));
+
+            cmp(reg_tmp_, with_tail_info_);
+            b(NE, if_end);
+            finalize_tail_loop(i_step, o_step, s_step, c_step, curr_node_id);
+            L(if_end);
+        }
+
+        // Restore the offsets to their initial values, i.e. the values
+        // they had before the loop was executed.
+        sub_imm(reg_off_in_, reg_off_in_, len * i_step * itype_sz_, X_TMP_0);
+        sub_imm(reg_off_out_, reg_off_out_, len * o_step * otype_sz_, X_TMP_0);
+        sub_imm(x_ptr_in_off, x_ptr_in_off, len * i_step * itype_sz_, X_TMP_0);
+        sub_imm(x_ptr_out_off, x_ptr_out_off, len * o_step * otype_sz_,
+                X_TMP_0);
 
         if (prb_.scale_type == scale_type_t::MANY) {
-            sub_imm(reg_off_scale, reg_off_scale, len * s_step * stype_sz,
+            sub_imm(reg_off_scale_, reg_off_scale_, len * s_step * stype_sz_,
                     X_TMP_0);
-            sub_imm(x_ptr_scale_off, x_ptr_scale_off, len * s_step * stype_sz,
+            sub_imm(x_ptr_scale_off, x_ptr_scale_off, len * s_step * stype_sz_,
                     X_TMP_0);
         }
+        if (compensation_needed_) {
+            sub_imm(reg_off_comp_, reg_off_comp_,
+                    len * c_step * sizeof(int32_t), X_TMP_0);
+            sub_imm(x_ptr_comp_off, x_ptr_comp_off,
+                    len * c_step * sizeof(int32_t), X_TMP_0);
+        }
     }
 
-    void compute_blk_ker(const int len_unroll) {
+    void compute_blk_ker(const simple_impl_desc_t &desc) {
+        static constexpr bool with_tail_processing = true;
+        Label no_last_chunk, end_label;
         int omp_ndims = prb_.full_ndims - prb_.ndims;
-        Label no_last_blk, end_label;
 
-        if (prb_.ip_tail > 0 && prb_.op_tail == 0) {
-            if (omp_ndims == 0) {
-                cmp(reg_last_loop_cnt, 1);
-                bne(no_last_blk);
-                compute_ker(omp_ndims, len_unroll, true);
-            } else {
-                cmp(reg_blk_chunks, blk_cnt());
-                bne(no_last_blk);
-                compute_ker(omp_ndims, len_unroll, true);
+        if (prb_.nodes[0].tail_size > 0) {
+            if (!prb_.nodes[0].is_parent_empty()) {
+                const int parent_node_id = prb_.nodes[0].parent_node_id;
+                ldr(reg_tmp_, ptr(data_chunk_addr(parent_node_id)));
+                check_if_this_is_last_chunk(reg_tmp_, parent_node_id);
+                b(NE, no_last_chunk);
             }
+
+            const int len_unroll = desc.tail_len_unroll > 0
+                    ? desc.tail_len_unroll
+                    : desc.len_unroll;
+            compute_ker(omp_ndims, len_unroll, with_tail_processing);
             b(end_label);
         }
 
-        L(no_last_blk);
-        compute_ker(omp_ndims, len_unroll, false);
+        L(no_last_chunk);
+        compute_ker(omp_ndims, desc.len_unroll, !with_tail_processing);
         L(end_label);
     }
 
+    void create_loops(const simple_impl_desc_t &desc,
+            const std::array<const XReg, 3> &reg_cnt, int jit_loop) {
+        assert(jit_loop <= ndims_jit_loop_max);
+        // Recursively emit `jit_loop` nested loops around compute_blk_ker().
+        if (jit_loop > 0) {
+            const int nfu = desc.ndims_full_unroll;
+            const int unroll_factor
+                    = jit_loop == 1 ? desc.len_last_dim_unroll : 1;
+            const int curr_node_id = nfu + (jit_loop - 1);
+            const int parent_node_id = prb_.nodes[curr_node_id].parent_node_id;
+            const int tail_size = prb_.tail(curr_node_id) / unroll_factor;
+            const int node_size = prb_.n(curr_node_id) / unroll_factor;
+            const XReg reg_loop_cnt = reg_cnt[jit_loop - 1];
+            const bool curr_node_has_tail = prb_.tail(curr_node_id) != 0;
+            Label loop, if_no_tail, if_end;
+            // Choose the trip count and push the tail flag that loop_end pops.
+            if (curr_node_has_tail) {
+                const size_t reg_bytes = reg_tmp_.getBit() / 8;
+                if (prb_.nodes[curr_node_id].is_parent_empty()) {
+                    mov(reg_loop_cnt, tail_size);
+                    // Push flag: node is being processed with tail.
+                    mov(reg_tmp_, with_tail_info_);
+                    str(reg_tmp_, pre_ptr(X_SP, -reg_bytes));
+                } else {
+                    ldr(reg_tmp_, ptr(data_chunk_addr(parent_node_id)));
+                    check_if_this_is_last_chunk(reg_tmp_, parent_node_id);
+                    b(NE, if_no_tail);
+                    mov(reg_loop_cnt, tail_size);
+                    // Push flag: node is being processed with tail.
+                    mov(reg_tmp_, with_tail_info_);
+                    str(reg_tmp_, pre_ptr(X_SP, -reg_bytes));
+                    b(if_end);
+
+                    L(if_no_tail);
+                    mov(reg_loop_cnt, node_size);
+                    // Push flag: node is being processed without tail.
+                    mov(reg_tmp_, without_tail_info_);
+                    str(reg_tmp_, pre_ptr(X_SP, -reg_bytes));
+                    L(if_end);
+                }
+            }
+            // If a child node has a tail, store the counter at data_chunk_addr.
+            if (prb_.is_tail_in_one_of_child_nodes(curr_node_id)) {
+                if (!curr_node_has_tail) {
+                    mov(reg_loop_cnt, node_size);
+                    str(reg_loop_cnt, ptr(data_chunk_addr(curr_node_id)));
+                }
+                L(loop);
+                if (!prb_.nodes[curr_node_id].is_parent_empty()) {
+                    Label if_no_tail_in_child_node;
+                    ldr(reg_tmp_, ptr(data_chunk_addr(parent_node_id)));
+                    check_if_this_is_last_chunk(reg_tmp_, parent_node_id);
+                    b(NE, if_no_tail_in_child_node);
+                    str(reg_loop_cnt, ptr(data_chunk_addr(curr_node_id)));
+                    L(if_no_tail_in_child_node);
+                } else {
+                    str(reg_loop_cnt, ptr(data_chunk_addr(curr_node_id)));
+                }
+            } else if (curr_node_has_tail) {
+                L(loop); // counter was already set in the tail branch above
+            } else {
+                loop_begin(loop, reg_loop_cnt, node_size);
+            }
+            create_loops(desc, reg_cnt, jit_loop - 1); // loop body: next level
+
+            loop_end(loop, reg_loop_cnt, node_size,
+                    prb_.is(curr_node_id) * unroll_factor,
+                    prb_.os(curr_node_id) * unroll_factor,
+                    prb_.ss(curr_node_id) * unroll_factor,
+                    prb_.cs(curr_node_id) * unroll_factor, curr_node_id);
+        } else {
+            compute_blk_ker(desc); // no loops left: emit the kernel body
+        }
+    }
+
     bool simple_impl() {
         simple_impl_desc_t d;
         if (!simple_impl_desc_init(prb_, &d)) return false;
 
-        const int nfu = d.ndims_full_unroll;
-        const int ldu = d.len_last_dim_unroll;
-        const int n_jit_loops = prb_.ndims - d.ndims_full_unroll;
-        assert(n_jit_loops <= ndims_jit_loop_max);
-
-        eor(reg_off_in, reg_off_in, reg_off_in);
-        eor(reg_off_out, reg_off_out, reg_off_out);
-        mov(x_ptr_in_off, XReg(reg_ptr_in.getIdx()));
-        mov(x_ptr_out_off, XReg(reg_ptr_out.getIdx()));
+        eor(reg_off_in_, reg_off_in_, reg_off_in_); // offsets start at zero
+        eor(reg_off_out_, reg_off_out_, reg_off_out_);
+        mov(x_ptr_in_off, reg_ptr_in_); // x_ptr_*_off = base + offset
+        mov(x_ptr_out_off, reg_ptr_out_);
         if (prb_.scale_type == scale_type_t::MANY) {
-            eor(reg_off_scale, reg_off_scale, reg_off_scale);
-            mov(x_ptr_scale_off, XReg(reg_ptr_scale.getIdx()));
+            mov(reg_off_scale_, 0);
+            mov(x_ptr_scale_off, reg_ptr_scale_);
+        }
+        if (compensation_needed_) {
+            eor(reg_off_comp_, reg_off_comp_, reg_off_comp_);
+            mov(x_ptr_comp_off, reg_ptr_comp_); // base pointer, not the offset
         }
 
-        Label l_loop[3];
-        XReg reg_cnt[3] = {x15, x14, x13};
-
-        if (n_jit_loops > 2) loop_begin(l_loop[2], reg_cnt[2], n(nfu + 2));
-
-        if (n_jit_loops > 1) loop_begin(l_loop[1], reg_cnt[1], n(nfu + 1));
-
-        if (n_jit_loops > 0)
-            loop_begin(l_loop[0], reg_cnt[0], n(nfu + 0) / ldu);
-
-        compute_blk_ker(d.len_unroll);
-
-        if (n_jit_loops > 0)
-            loop_end(l_loop[0], reg_cnt[0], n(nfu + 0) / ldu, is(nfu + 0) * ldu,
-                    os(nfu + 0) * ldu, ss(nfu + 0) * ldu);
-
-        if (n_jit_loops > 1)
-            loop_end(l_loop[1], reg_cnt[1], n(nfu + 1), is(nfu + 1),
-                    os(nfu + 1), ss(nfu + 1));
+        std::array<const XReg, 3> reg_cnt({{x15, x14, x13}});
 
-        if (n_jit_loops > 2)
-            loop_end(l_loop[2], reg_cnt[2], n(nfu + 2), is(nfu + 2),
-                    os(nfu + 2), ss(nfu + 2));
+        const int n_jit_loops = prb_.ndims - d.ndims_full_unroll;
+        create_loops(d, reg_cnt, n_jit_loops); // emit loop nest + kernel body
 
         return true;
     }
@@ -1156,7 +1577,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
         inst(__VA_ARGS__);
 
     void cvt_z_s32_f32(const size_t startIdx, const size_t regNum) {
-        UNROLL_INST(scvtf, ZRegS, tmp, p_all / T_m, tmp);
+        UNROLL_INST(scvtf, ZRegS, tmp, P_ALL_ONE / T_m, tmp);
     }
 
     void cvt_v_s32_f32(const size_t startIdx, const size_t regNum) {
@@ -1164,8 +1585,8 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
     }
 
     void cvt_z_f32_s32(const size_t startIdx, const size_t regNum) {
-        UNROLL_INST(frinti, ZRegS, tmp, p_all / T_m, tmp);
-        UNROLL_INST(fcvtzs, ZRegS, tmp, p_all / T_m, tmp);
+        UNROLL_INST(frinti, ZRegS, tmp, P_ALL_ONE / T_m, tmp);
+        UNROLL_INST(fcvtzs, ZRegS, tmp, P_ALL_ONE / T_m, tmp);
     }
 
     void cvt_v_f32_s32(const size_t startIdx, const size_t regNum) {
@@ -1175,7 +1596,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
 
     void cvt_z_s8_s32(const size_t startIdx, const size_t regNum) {
         cvt_z_b_s(startIdx, regNum);
-        UNROLL_INST(sxtb, ZRegS, tmp, p_all / T_m, tmp);
+        UNROLL_INST(sxtb, ZRegS, tmp, P_ALL_ONE / T_m, tmp);
     }
 
     void cvt_v_s8_s32(const size_t startIdx, const size_t regNum) {
@@ -1214,7 +1635,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
 
     void cvt_z_u8_s32(const size_t startIdx, const size_t regNum) {
         cvt_z_b_s(startIdx, regNum);
-        UNROLL_INST(uxtb, ZRegS, tmp, p_all / T_m, tmp);
+        UNROLL_INST(uxtb, ZRegS, tmp, P_ALL_ONE / T_m, tmp);
     }
 
     void cvt_v_u8_s32(const size_t startIdx, const size_t regNum) {
@@ -1285,7 +1706,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
 
         dupm(z_tmp7.s, 255);
         UNROLL_INST2(smax, ZRegS(i), 0);
-        UNROLL_INST2(smin, ZRegS(i), p_all / T_m, z_tmp7.s);
+        UNROLL_INST2(smin, ZRegS(i), P_ALL_ONE / T_m, z_tmp7.s);
         UNROLL_INST(uzp1, ZRegH, tmp, tmp, tmp);
         UNROLL_INST(uzp1, ZRegB, tmp, tmp, tmp);
         UNROLL_INST2(mov, ZRegB(i), P_NOT_128 / T_m, 0);
@@ -1320,107 +1741,514 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
 #undef UNROLL_INST
 #undef UNROLL_INST
 
-    jit_uni_reorder_kernel_f32_t(const desc_t &desc) : kernel_t(desc) {
-        itype_sz = data_type_size(prb_.itype);
-        otype_sz = data_type_size(prb_.otype);
-        stype_sz = sizeof(float);
+    jit_uni_reorder_kernel_f32_t(const desc_t &desc)
+        : kernel_t(desc), isa_(get_max_cpu_isa()) {
+        assert(!utils::one_of(isa_, isa_undef, isa_all));
+        itype_sz_ = data_type_size(prb_.itype);
+        otype_sz_ = data_type_size(prb_.otype);
+        stype_sz_ = sizeof(float);
     }
 
     void generate() override {
         using namespace Xbyak_aarch64::util;
         uint64_t sveLen = get_sve_length();
+        Label end_of_kernel;
 
         preamble();
-#define PARAM(x) offsetof(call_param_t, x)
+
         if (prb_.scale_type == scale_type_t::COMMON) {
-            add_imm(X_DEFAULT_ADDR, abi_param1, PARAM(scale), X_TMP_1);
+            add_imm(X_DEFAULT_ADDR, PARAM(scale), X_TMP_1);
             ldr(X_TMP_0, ptr(X_DEFAULT_ADDR));
-            ldr(W_TMP_1, ptr(X_TMP_0));
-            dup(xmm_scale, W_TMP_1);
+            ld1r(xmm_scale_, ptr(X_TMP_0));
         } else if (prb_.scale_type == scale_type_t::MANY) {
-            add_imm(X_DEFAULT_ADDR, abi_param1, PARAM(scale), X_TMP_0);
-            ldr(reg_ptr_scale, ptr(X_DEFAULT_ADDR));
+            add_imm(X_DEFAULT_ADDR, PARAM(scale), X_TMP_0);
+            ldr(reg_ptr_scale_, ptr(X_DEFAULT_ADDR));
         }
-        add_imm(X_TMP_0, abi_param1, PARAM(in), X_TMP_2);
-        add_imm(X_TMP_1, abi_param1, PARAM(out), X_TMP_2);
-        add_imm(reg_blk, abi_param1, PARAM(blk_chunks), reg_blk);
-        ldr(reg_ptr_in, ptr(X_TMP_0));
-        ldr(reg_ptr_out, ptr(X_TMP_1));
-        ldr(reg_blk_chunks, ptr(reg_blk));
-
-#undef PARAM
-        mov_imm(reg_last_loop_cnt, 1);
+        if (compensation_needed_) {
+            add_imm(X_DEFAULT_ADDR, PARAM(compensation_scratch), X_TMP_0);
+            ldr(reg_ptr_comp_, ptr(X_DEFAULT_ADDR));
+        }
+        if (prb_.scale_adjust == 0.5f) { mov(reg_scale_adjust_, 0x3f000000); }
+        add_imm(X_TMP_0, PARAM(in), X_TMP_2);
+        add_imm(X_TMP_1, PARAM(out), X_TMP_2);
+        ldr(reg_ptr_in_, ptr(X_TMP_0));
+        ldr(reg_ptr_out_, ptr(X_TMP_1));
 
-        mov(x_ptr_in_off, XReg(reg_ptr_in.getIdx()));
-        mov(x_ptr_out_off, XReg(reg_ptr_out.getIdx()));
-        mov(x_ptr_scale_off, XReg(reg_ptr_scale.getIdx()));
+        mov(x_ptr_in_off, reg_ptr_in_);
+        mov(x_ptr_out_off, reg_ptr_out_);
+        mov(x_ptr_scale_off, reg_ptr_scale_);
+        mov(x_ptr_comp_off, reg_ptr_comp_);
 
         if (sveLen) { /* SVE is available. */
             ptrue(p_lsb_256.b, VL32);
-            ptrue(p_all.b);
+            ptrue(p_lsb_128.b, VL16);
+            ptrue(p_lsb_64.b, VL8);
         }
 
-        if (can_do_tr8x8()) {
-            dup(ymm_zero, 0);
-
-            if (prb_.itype == data_type::u8 && prb_.otype == data_type::s8) {
-                mov_imm(reg_tmp, 0x7f7f7f7f7f7f7f7f);
-                mov(VReg4S(ymm_8x127b.getIdx())[0], WReg(reg_tmp.getIdx()));
+        bool is_tail_in_drv_dims = false;
+        for (int i = prb_.ndims; i < prb_.full_ndims; i++)
+            if (prb_.nodes[i].tail_size > 0) {
+                is_tail_in_drv_dims = true;
+                break;
             }
-        } else if (mayiuse(sve_512)) {
-            movi(xmm_zero, 0);
 
-            if (prb_.itype == data_type::u8 && prb_.otype == data_type::s8) {
-                mov(WReg(reg_tmp.getIdx()), 0x7f7f7f7f);
-                mov(xmm_4x127b[0], WReg(reg_tmp.getIdx()));
+        if (is_tail_in_drv_dims) {
+            Label reorder_kernel;
+            add_imm(X_DEFAULT_ADDR, TAIL_PARAM(skip_kernel_execution), X_TMP_0);
+            ldr(reg_tmp_, ptr(X_DEFAULT_ADDR));
+            cmp(reg_tmp_, static_cast<int64_t>(true));
+            b(EQ, end_of_kernel);
+
+            add_imm(X_DEFAULT_ADDR, TAIL_PARAM(zeroing_data), X_TMP_0);
+            ldr(reg_tmp_, ptr(X_DEFAULT_ADDR));
+            cmp(reg_tmp_, static_cast<int64_t>(false));
+            b(EQ, reorder_kernel);
+            // If zeroing data is set then all dst memory
+            // will be zeroed and nothing more will be done.
+            int bytes_to_zeroing = otype_sz_;
+            for (int i = 0; i < prb_.ndims; i++) {
+                bytes_to_zeroing *= prb_.nodes[i].n;
             }
+            eor(reg_off_out_, reg_off_out_, reg_off_out_);
+            mov(x_ptr_out_off, reg_ptr_out_);
+            zero_dst_memory(bytes_to_zeroing);
+            b(end_of_kernel);
+            L(reorder_kernel);
+        }
+
+        if (can_do_tr8x8()) {
+            dup(ymm_zero_, 0);
+        } else {
+            movi(xmm_zero_, 0);
         }
 
         impl();
+
+        L(end_of_kernel);
         postamble();
     }
 
+    ~jit_uni_reorder_kernel_f32_t() override = default;
+
+#undef TAIL_PARAM
+#undef PARAM
+
 private:
-    int itype_sz;
-    int otype_sz;
-    int stype_sz;
+    static constexpr int64_t with_tail_info_ = static_cast<int64_t>(true);
+    static constexpr int64_t without_tail_info_ = static_cast<int64_t>(false);
+
+    int itype_sz_;
+    int otype_sz_;
+    int stype_sz_;
 
-    XReg reg_ptr_in = x6;
-    XReg reg_ptr_out = x2;
-    XReg reg_ptr_scale = abi_not_param1;
+    const cpu_isa_t isa_;
 
-    XReg reg_off_in = x8;
-    XReg reg_off_out = x9;
-    XReg reg_off_scale = x10;
+    const XReg reg_ptr_in_ = x6;
+    const XReg reg_ptr_out_ = x2;
+    const XReg reg_ptr_scale_ = abi_not_param1;
+    const XReg reg_ptr_comp_ = x3;
+    const WReg &reg_scale_adjust_ = w5;
 
-    XReg reg_blk = x11;
-    XReg reg_blk_chunks = x12;
-    XReg reg_last_loop_cnt = x11;
+    const XReg reg_off_in_ = x8;
+    const XReg reg_off_out_ = x9;
+    const XReg reg_off_scale_ = x10;
+    const XReg reg_off_comp_ = x11;
 
-    XReg reg_tmp = x0;
+    XReg reg_tmp_ = x12;
 
-    VReg4S xmm_scale = v15.s;
-    VReg4S xmm_zero = v14.s;
-    VReg4S xmm_4x127b = v13.s; // TODO: unite with ymm_zero
-    ZRegS ymm_zero = z14.s;
-    ZRegS ymm_8x127b = z13.s;
-    VReg4S xmm_tmp = v12.s;
-    VReg4S xmm_saturation_ubound = v12.s;
-    ZRegS ymm_saturation_ubound = z12.s;
+    VReg4S xmm_scale_ = v15.s;
+    VReg4S xmm_zero_ = v14.s;
+    ZRegS ymm_zero_ = z14.s;
+    VReg4S xmm_tmp_ = v12.s;
+    const VReg4S xmm_src_zp_ = v9.s;
+    const VReg4S xmm_dst_zp_ = v11.s;
+    VReg4S xmm_saturation_ubound_ = v12.s;
+    ZRegS ymm_saturation_ubound_ = z12.s;
 
     /* Note: x22 - x28 are already used as temporal registgers
        in jit_generator.hpp.
-       x_ptr_(in|out|scale)_off keeps (base + offset) address. */
+       x_ptr_(in|out|scale|comp)_off keeps (base + offset) address. */
     XReg x_ptr_in_off = x16;
     XReg x_ptr_out_off = x18;
     XReg x_ptr_scale_off = x20;
+    XReg x_ptr_comp_off = x17;
 
     /* Caution: Chose predicate registers not used by x64's implementation. */
     PReg p_lsb_256 = p7;
-    PReg p_all = p6;
+    PReg p_lsb_128 = p6;
+    PReg p_lsb_64 = p4;
     PReg p_tmp0 = p5;
 
     const std::vector<uint32_t> tmp_vec_idx = {20, 21, 22, 23, 24, 25, 26, 27};
+    VReg v_tmp0 = v20;
+    ZReg z_tmp0 = z20;
+    ZReg z_tmp1 = z21;
+    ZReg z_tmp2 = z22;
+    ZReg z_tmp3 = z23;
+    ZReg z_tmp4 = z24;
+    ZReg z_tmp5 = z25;
+    ZReg z_tmp6 = z26;
+    ZReg z_tmp7 = z27;
+    VReg v_tmp7 = v27;
+
+    const std::vector<ZReg> z_tmp_vec
+            = {z_tmp0, z_tmp1, z_tmp2, z_tmp3, z_tmp4, z_tmp5, z_tmp6, z_tmp7};
+    constexpr static int z_tmp_vec_size = 8;
+};
+
+// Separate class for no unroll/threading burden
+struct jit_single_blk_kernel_t : public jit_generator {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_single_blk_kernel)
+    static bool applicable(const prb_t &p) {
+        using namespace data_type;
+
+        bool ok = p.ndims >= 2 && mayiuse(sve_256)
+                && p.scale_type == scale_type_t::NONE
+                && utils::one_of(p.itype, f32) && utils::one_of(p.otype, f32)
+                && utils::everyone_is(0, p.ioff, p.ooff) && p.beta == 0.f
+                && prb_has_small_strides(p);
+        if (!ok) return false;
+
+        int64_t n0 = p.nodes[0].n;
+        auto i0 = p.nodes[0].is;
+        auto o0 = p.nodes[0].os;
+        int64_t n1 = p.nodes[1].n;
+        auto i1 = p.nodes[1].is;
+        auto o1 = p.nodes[1].os;
+
+        /*
+         * for a transpose of plain to 8c case, nodes would be like:
+         *     n    is   os
+         *     m    1    8
+         *     8    m    1
+         * or
+         *     8    m    1
+         *     m    1    8
+         */
+        ok = (utils::one_of(n0, 8, 16) || utils::one_of(n1, 8, 16))
+                && ((i0 == 1 && o1 == 1 && n0 == i1 && o0 == n1)
+                        || (o0 == 1 && i1 == 1 && n0 == o1 && i0 == n1));
+        if (!ok) return false;
+
+        // Do not handle transpose of dimensions other than last 2
+        for (int i = 2; i < p.ndims; ++i) {
+            if (p.nodes[i].is != p.nodes[i].os) {
+                ok = false;
+                break;
+            }
+        }
+
+        return ok;
+    }
+
+    jit_single_blk_kernel_t(const tr::prb_t &prb)
+        : jit_generator()
+        , prb_(prb)
+        , itype_sz_(data_type_size(prb_.itype))
+        , otype_sz_(data_type_size(prb_.otype))
+        , block_sz(prb.nodes[0].n) {}
+
+    void generate() override {
+        auto input_stride
+                = prb_.nodes[0].is != 1 ? prb_.nodes[0].is : prb_.nodes[1].is;
+        auto output_stride
+                = prb_.nodes[0].os != 1 ? prb_.nodes[0].os : prb_.nodes[1].os;
+
+        Label tail_processing;
+
+        const auto load_zp = [&](const ZRegS ymm_zp, const XReg reg_zp) {
+            dup(ymm_zp, WReg(reg_zp.getIdx()));
+            scvtf(ymm_zp, P_ALL_ONE / T_m, ymm_zp);
+        };
+
+        preamble();
+
+        if (prb_.req_src_zp) load_zp(ymm_src_zp, reg_src_zp);
+
+        if (prb_.req_dst_zp) load_zp(ymm_dst_zp, reg_dst_zp);
+
+        cmp(reg_ptr_tail, true);
+        b(EQ, tail_processing);
+
+        if (block_sz == 8) {
+            gen_ker8x8(0, 0, input_stride, output_stride, 8, 8);
+            block_sz = 8;
+        } else if (block_sz == 16) {
+            gen_ker16x16_in_8x8(input_stride, output_stride);
+            block_sz = 16;
+        } else {
+            assert(!"unimplemented");
+        }
+
+        postamble();
+
+        L(tail_processing);
+
+        if (block_sz == 8) {
+            auto i_tail = input_stride % 8 != 0 ? input_stride % 8 : 8;
+            auto o_tail = output_stride % 8 != 0 ? output_stride % 8 : 8;
+            if (i_tail != o_tail) {
+                auto t_mask = i_tail == 8 ? o_tail : i_tail;
+                gen_setmask(t_mask);
+                gen_ker8x8(0, 0, input_stride, output_stride, i_tail, o_tail);
+            }
+        } else if (block_sz == 16) {
+            auto i_tail = input_stride % 16 != 0 ? input_stride % 16 : 16;
+            auto o_tail = output_stride % 16 != 0 ? output_stride % 16 : 16;
+            if (i_tail != o_tail) {
+                auto t_mask = i_tail == 16 ? o_tail : i_tail;
+                t_mask %= 8;
+                if (t_mask != 0) gen_setmask(t_mask);
+                gen_ker16x16_in_8x8(
+                        input_stride, output_stride, i_tail, o_tail);
+            }
+        } else {
+            assert(!"unimplemented");
+        }
+
+        postamble();
+    }
+
+    void gen_loadu(const ZRegS ymm, const XReg &addr, int size) {
+        QReg xmm(ymm.getIdx());
+        switch (size) {
+            case 32: ld1w(ymm, p_lsb_256 / T_z, ptr(addr)); break;
+            case 16: ldr(xmm, ptr(addr)); break;
+            default: assert(!"unreachable");
+        }
+    }
+
+    void gen_storeu(const XReg &addr, const ZRegS ymm, int size) {
+        QReg xmm(ymm.getIdx());
+        switch (size) {
+            case 32: st1w(ymm, p_lsb_256, ptr(addr)); break;
+            case 16: str(xmm, ptr(addr)); break;
+            default: assert(!"unreachable");
+        }
+    }
+
+    void gen_maskloadu(
+            const ZRegS ymm, const XReg &addr, const PReg mask, int size) {
+        switch (size) {
+            case 32:
+            case 16: ld1w(ymm, mask / T_z, ptr(addr)); break;
+            default: assert(!"unreachable");
+        }
+    }
+
+    void gen_maskstoreu(
+            const XReg &addr, const ZRegS ymm, const PReg mask, int size) {
+        switch (size) {
+            case 32:
+            case 16: st1w(ymm, mask, ptr(addr)); break;
+            default: assert(!"unreachable");
+        }
+    }
+
+    // Register allocation xmm0~11
+    void gen_transpose_8x8() {
+        const uint64_t sveLen = get_sve_length();
+        constexpr int lane = 8;
+
+#if 0
+        /* Debug code
+	   z0:   7,  6,  5,  4,  3,  2,  1,  0
+	   z1:  15, 14, 13, 12, 11, 10,  9,  8
+	   ...
+	   z17: 63, 62, 61, 60, 59, 58, 57, 56
+	*/
+	ptrue(P_ALL_ONE.b);
+	ptrue(P_TMP.s, VL8);
+	not_(P_TMP.b, P_ALL_ONE/T_z, P_TMP.b);
+        index(z0.s, 0, 1);
+        mov(z0.s, P_TMP/T_m, 0);
+        mov(z_tmp_vec[0].s, 8);
+        mov(z_tmp_vec[0].s, P_TMP/T_m, 0);
+        for(uint32_t i=1; i<lane; i++)
+          add(ZRegS{i}, ZRegS{i-1}, z_tmp_vec[0].s);
+#endif
+
+        ptrue(P_TMP.s, VL4);
+
+        /* 1st turn */
+        for (uint32_t i = 0; i < lane / 2; i++) {
+            trn1(z_tmp_vec[i].s, ZRegS {2 * i}, ZRegS {2 * i + 1});
+            trn2(z_tmp_vec[lane / 2 + i].s, ZRegS {2 * i}, ZRegS {2 * i + 1});
+        }
+
+        /* 2nd turn */
+        trn1(z4.d, z_tmp_vec[0].d, z_tmp_vec[1].d);
+        trn1(z5.d, z_tmp_vec[4].d, z_tmp_vec[5].d);
+        trn2(z6.d, z_tmp_vec[0].d, z_tmp_vec[1].d);
+        trn2(z7.d, z_tmp_vec[4].d, z_tmp_vec[5].d);
+        trn1(z_tmp_vec[0].d, z_tmp_vec[2].d, z_tmp_vec[3].d);
+        trn1(z_tmp_vec[1].d, z_tmp_vec[6].d, z_tmp_vec[7].d);
+        trn2(z_tmp_vec[2].d, z_tmp_vec[2].d, z_tmp_vec[3].d);
+        trn2(z_tmp_vec[3].d, z_tmp_vec[6].d, z_tmp_vec[7].d);
+
+        /* 3rd turn */
+        for (uint32_t i = 0; i < lane / 2; i++) {
+            mov(ZRegD {i}, ZRegD {lane / 2 + i});
+            mov(z_tmp_vec[lane / 2 + i].d, z_tmp_vec[i].d);
+        }
+
+        /* 4th turn */
+        for (uint32_t i = 0; i < lane / 2; i++) {
+            ZRegB z {lane / 2 + i};
+            ZRegB z_tmp = z_tmp_vec[lane / 2 + i].b;
+            /* Move bit 0-127 to 128-255. */
+            ext(z, z, 16);
+            /* Move bit 128-255 to 0-127. */
+            ext(z_tmp, z_tmp, sveLen - 16);
+        }
+
+        /* 5th turn */
+        for (uint32_t i = 0; i < lane / 2; i++) {
+            ZRegS z0 {i};
+            ZRegS z1 {lane / 2 + i};
+            sel(z0, P_TMP, z0, z_tmp_vec[lane / 2 + i].s);
+            sel(z1, P_TMP, z1, z_tmp_vec[i].s);
+        }
+    }
+
+    // Set p_mask via whilelt on 32-bit lanes with bounds 0 and `mask`; used for
+    // tails (keep order nchw -> nChw()C or nChw()C -> nchw).
+    void gen_setmask(int mask) {
+        mov_imm(x_tmp_0, 0);
+        mov_imm(x_tmp_1, mask);
+        whilelt(p_mask.s, x_tmp_0, x_tmp_1);
+    }
+
+    // Generate one 8x8 transpose tile for the given tail condition.
+    // TODO: Mark parameter with type information
+    // i_off / o_off: byte offsets added to reg_ptr_in_ / reg_ptr_out_.
+    // input_stride / output_stride: in elements (scaled by itype_sz_ /
+    // otype_sz_ here). Loads out_tail rows and stores in_tail rows; p_mask is
+    // used for loads when in_tail != 8 and for stores when out_tail != 8.
+    void gen_tr8x8(int i_off, int o_off, int input_stride, int output_stride,
+            int in_tail, int out_tail) {
+        constexpr int lane = 8;
+
+        if (in_tail == 0 || out_tail == 0) return;
+
+        for (int i = 0; i < out_tail; ++i) {
+            if (in_tail != lane) {
+                add_imm(x_addr, reg_ptr_in_,
+                        i_off + i * input_stride * itype_sz_, x_tmp_0);
+                gen_maskloadu(ZRegS(i), x_addr, p_mask, lane * itype_sz_);
+            } else {
+                add_imm(x_addr, reg_ptr_in_,
+                        i_off + i * input_stride * itype_sz_, x_tmp_0);
+                gen_loadu(ZRegS(i), x_addr, lane * itype_sz_);
+            }
+            if (prb_.req_src_zp) { fsub(ZRegS(i), ZRegS(i), ymm_src_zp); }
+        }
+
+        gen_transpose_8x8();
+
+        for (int i = 0; i < in_tail; ++i) {
+            if (prb_.req_dst_zp) { fadd(ZRegS(i), ZRegS(i), ymm_dst_zp); }
+            if (out_tail == lane) {
+                add_imm(x_addr, reg_ptr_out_,
+                        o_off + i * output_stride * otype_sz_, x_tmp_0);
+                gen_storeu(x_addr, ZRegS(i), lane * otype_sz_);
+            } else {
+                add_imm(x_addr, reg_ptr_out_,
+                        o_off + i * output_stride * otype_sz_, x_tmp_0);
+                gen_maskstoreu(x_addr, ZRegS(i), p_mask, lane * otype_sz_);
+            }
+        }
+    }
+
+    // Thin wrapper over gen_tr8x8. tail: 0 ~ 8 (a zero tail emits nothing).
+    // support: either in_tail or out_tail is not 8, but not both
+    void gen_ker8x8(int i_off, int o_off, int input_stride, int output_stride,
+            int in_tail, int out_tail) {
+        gen_tr8x8(i_off, o_off, input_stride, output_stride, in_tail, out_tail);
+    }
+
+    void gen_ker16x16_in_8x8(int input_stride, int output_stride) {
+        const auto lane = 16;
+        const auto sub_lane = lane / 2;
+        gen_tr8x8(0, 0, input_stride, output_stride, sub_lane, sub_lane);
+        gen_tr8x8(input_stride * sub_lane * itype_sz_, sub_lane * otype_sz_,
+                input_stride, output_stride, sub_lane, sub_lane);
+        gen_tr8x8(sub_lane * itype_sz_, output_stride * sub_lane * otype_sz_,
+                input_stride, output_stride, sub_lane, sub_lane);
+        gen_tr8x8((input_stride * sub_lane + sub_lane) * itype_sz_,
+                (output_stride * sub_lane + sub_lane) * otype_sz_, input_stride,
+                output_stride, sub_lane, sub_lane);
+    }
+
+    // Tail variant of the 16x16 kernel (four 8x8 tiles); tail can be 1 ~ 16.
+    void gen_ker16x16_in_8x8(
+            int input_stride, int output_stride, int in_tail, int out_tail) {
+        constexpr auto lane = 16;
+        constexpr auto sub_lane = lane / 2;
+        auto tail = in_tail != lane ? in_tail : out_tail;
+        // Split the tail over the two 8-wide halves: l_tail first, u_tail second.
+        const auto l_tail = tail < sub_lane ? tail : sub_lane;
+        const auto u_tail = tail < sub_lane ? 0 : tail - sub_lane;
+
+        if (tail == in_tail) {
+            gen_tr8x8(0, 0, input_stride, output_stride, l_tail, sub_lane);
+            gen_tr8x8(input_stride * sub_lane * itype_sz_, sub_lane * otype_sz_,
+                    input_stride, output_stride, l_tail, sub_lane);
+            gen_tr8x8(sub_lane * itype_sz_,
+                    output_stride * sub_lane * otype_sz_, input_stride,
+                    output_stride, u_tail, sub_lane);
+            gen_tr8x8(itype_sz_ * (input_stride * sub_lane + sub_lane),
+                    otype_sz_ * (output_stride * sub_lane + sub_lane),
+                    input_stride, output_stride, u_tail, sub_lane);
+        } else {
+            gen_tr8x8(0, 0, input_stride, output_stride, sub_lane, l_tail);
+            gen_tr8x8(input_stride * sub_lane * itype_sz_, sub_lane * otype_sz_,
+                    input_stride, output_stride, sub_lane, u_tail);
+            gen_tr8x8(sub_lane * itype_sz_,
+                    output_stride * sub_lane * otype_sz_, input_stride,
+                    output_stride, sub_lane, l_tail);
+            gen_tr8x8(itype_sz_ * (input_stride * sub_lane + sub_lane),
+                    otype_sz_ * (output_stride * sub_lane + sub_lane),
+                    input_stride, output_stride, sub_lane, u_tail);
+        }
+    }
+
+private:
+    // 6 ~ 12
+    constexpr static int xmm_save_for_windows = 0;
+    constexpr static int xmm_save_start_from = 6;
+    constexpr static int xmm_width = 16;
+
+    void preamble() { ptrue(p_lsb_256.b, VL32); }
+
+    void postamble() { ret(); }
+
+    const prb_t &prb_;
+
+    int itype_sz_;
+    int otype_sz_;
+    int block_sz;
+
+    XReg reg_ptr_in_ = abi_param1;
+    XReg reg_ptr_out_ = abi_param2;
+    XReg reg_ptr_tail = abi_param3;
+    XReg reg_src_zp = abi_param4;
+    XReg reg_dst_zp = abi_param5;
+
+    XReg x_addr = x10;
+    XReg x_tmp_0 = x11;
+    XReg x_tmp_1 = x12;
+
+    /* Avoid P_TMP(p7) in jit_generator.hpp. */
+    PReg p_lsb_256 = p6;
+    PReg p_mask = p5;
+
+    ZRegS ymm_tmp = z0.s;
+    ZRegS ymm_src_zp = z14.s;
+    ZRegS ymm_dst_zp = z15.s;
+
+    const std::vector<uint32_t> tmp_vec_idx = {20, 21, 22, 23, 24, 25, 26, 27};
+    VReg v_tmp0 = v20;
     ZReg z_tmp0 = z20;
     ZReg z_tmp1 = z21;
     ZReg z_tmp2 = z22;
@@ -1472,15 +2300,31 @@ kernel_t *kernel_t::create(const kernel_t::desc_t &desc) {
 
     return nullptr;
 }
+
 } // namespace tr
 
 static void prb_block_for_cache(tr::prb_t &prb) {
     /* If strides for 0th and 1st nodes are cache friendly
      * then one can altogether do away with blocking ! */
-    const bool cache_blocking_needed = false
-            || (prb.nodes[0].is % 64 == 0 && prb.nodes[0].n > 16)
-            || (prb.ndims > 1 && prb.nodes[1].is % 64 == 0
-                    && prb.nodes[1].n > 16);
+    static constexpr int num_elems_thr = 16;
+    const bool stride_cache_friendly
+            = ((prb.nodes[0].is % 64 == 0 && prb.nodes[0].n > num_elems_thr)
+                      || (prb.ndims > 1 && prb.nodes[1].is % num_elems_thr == 0
+                              && prb.nodes[1].n > num_elems_thr))
+            && !prb.is_tail_present;
+
+    // performance improvement for shapes with large inner-most dimension
+    const size_t L1_cache_sz
+            = size_t(3) * platform::get_per_core_cache_size(1) / 4;
+    const size_t itype_sz_ = data_type_size(prb.itype);
+    const size_t inner_block_sz = prb.nodes[0].n * itype_sz_;
+    const bool requires_inner_blocking = inner_block_sz > L1_cache_sz
+            // 'is_tail_present' is not supported for cache_blocking when
+            // asymmetric_comp is executed.
+            && IMPLICATION(prb.req_asymmetric_comp, !prb.is_tail_present);
+
+    const bool cache_blocking_needed
+            = stride_cache_friendly || requires_inner_blocking;
     if (!cache_blocking_needed) return;
 
     int unit_input_stride_idx = -1;
@@ -1496,28 +2340,58 @@ static void prb_block_for_cache(tr::prb_t &prb) {
         const auto output_stride = prb.nodes[unit_input_stride_idx].os;
         const auto num_elems = prb.nodes[unit_input_stride_idx].n;
 
-        const bool split_needed = (num_elems > 16) && (num_elems % 16 == 0);
+        const bool split_needed = (num_elems > num_elems_thr)
+                && (num_elems % num_elems_thr == 0);
         const int move_location = (output_stride % 4 != 0) ? 0 : 1;
-        if (split_needed) prb_node_split(prb, unit_input_stride_idx, 16);
+        if (split_needed)
+            prb_node_split(prb, unit_input_stride_idx, num_elems_thr);
 
         /* Because of cache-unfriendly nature of unit-output stride node, let
          * us move unit-input stride node on or near front! */
-        prb_node_move(prb, unit_input_stride_idx, move_location);
+        if (unit_input_stride_idx != move_location)
+            prb_node_move(prb, unit_input_stride_idx, move_location);
     }
 
     /* Potentially, split the node with os=1 in two and pull in the node with
      * is=1 between them for better cache reuse:
      * [n0:is0:1][n1:1:os1] --> [16n0:is0:1][n1:1:os1][n0/16:is0*16:16] */
     if (prb.ndims >= 2 && prb.nodes[0].os == 1 && prb.nodes[1].is == 1) {
-        const auto input_stride = prb.nodes[0].is;
         const auto num_elems = prb.nodes[0].n;
 
-        const bool split_needed = true && (num_elems > 16)
-                && (num_elems % 16 == 0) && (input_stride >= 256)
-                && (input_stride % 64 == 0);
+        const bool split_needed = (num_elems > num_elems_thr)
+                && (num_elems % num_elems_thr == 0);
         if (split_needed) {
-            prb_node_split(prb, 0, 16);
+            prb_node_split(prb, 0, num_elems_thr);
             prb_node_move(prb, 1, 2);
+
+            // Update node information
+            prb_node_dependency(prb);
+
+            // heuristics - looping over the unrolled dims should maximize reuse
+            // of the already cached data; observation is choosing the smallest
+            // dim from the remaining (from 2 up to ndims) gives good results
+            constexpr int new_position = 2;
+            const auto dim_beg_it = std::begin(prb.nodes);
+            const auto dim_two_it = dim_beg_it + new_position;
+            const auto dim_last_it = dim_beg_it + prb.ndims;
+            const auto min_n_node_it = std::min_element(dim_two_it, dim_last_it,
+                    [](const tr::node_t &lhs, const tr::node_t &rhs) {
+                        return lhs.n < rhs.n;
+                    });
+            const auto min_idx = std::distance(dim_beg_it, min_n_node_it);
+            // check if min_idx node is parent of node with tail processing which
+            // is currently unsupported (i.e. tail processing can only be handled
+            // at the inner-most dimension)
+            bool inner_block_has_tail = false;
+            for (int idx = min_idx - 1; idx >= new_position; idx--) {
+                if (prb.nodes[idx].parent_node_id == min_idx) {
+                    inner_block_has_tail = true;
+                    break;
+                }
+            }
+
+            if (min_idx > new_position && (!inner_block_has_tail))
+                prb_node_move(prb, min_idx, new_position);
         }
     }
 }
@@ -1527,73 +2401,76 @@ static void prb_block_for_cache(tr::prb_t &prb) {
  * parallel driver and the kernel. */
 static void prb_thread_kernel_balance(
         tr::prb_t &prb, int &ndims_ker_max, int nthr) {
-    size_t sz_total = 1;
+    size_t size_total = 1;
     for (int d = 0; d < prb.ndims; ++d)
-        sz_total *= prb.nodes[d].n;
+        size_total *= prb.nodes[d].n;
 
-    /* The general expression for sz_drv_thr can be written as
-     * sz_drv_min = C0 + FC * (nthr > 1 ? 1 : 0) + VC * (nthr - 1)
+    /* The general expression for size_drv_thr can be written as
+     * size_drv_min = C0 + FC * (nthr > 1 ? 1 : 0) + VC * (nthr - 1)
      * where FC and VC are fixed and variable costs respectively.
      * Though for now, the below heuristic seems to be good enough */
-    const size_t sz_drv_thr = (nthr > 1) ? 16 * nthr : 1;
+    const size_t size_drv_thr = (nthr > 1) ? 16 * nthr : 1;
 
-    /* sz_drv_min is the minimal size for the parallel
+    /* size_drv_min is the minimal size for the parallel
      * driver required for good parallelization */
-    const size_t sz_drv_min
-            = nstl::min<size_t>(sz_drv_thr, utils::div_up(sz_total, 1024));
+    const size_t size_drv_min
+            = nstl::min<size_t>(size_drv_thr, utils::div_up(size_total, 1024));
 
     /* kdims -- # of dimensions processed by a kernel
-     * sz_ker_cur -- product of the dimension processed by a kernel
-     * sz_drv_cur -- product of the dimension processed by a driver */
+     * size_ker_cur -- product of the dimension processed by a kernel
+     * size_drv_cur -- product of the dimension processed by a driver */
 
     int kdims = prb.ndims;
-    size_t sz_drv_cur = 1;
-    for (; kdims > 1 && sz_drv_cur < sz_drv_min; --kdims)
-        sz_drv_cur *= prb.nodes[kdims - 1].n;
+    size_t size_drv_cur = 1;
+    for (; kdims > 1 && size_drv_cur < size_drv_min; --kdims)
+        size_drv_cur *= prb.nodes[kdims - 1].n;
 
-    size_t sz_ker_cur = 1;
+    size_t size_ker_cur = 1;
     for (int d = 0; d < kdims; ++d)
-        sz_ker_cur *= prb.nodes[d].n;
+        size_ker_cur *= prb.nodes[d].n;
 
-    /* Initially kdims is chosen so that sz_drv_cur >= sz_drv_min.
+    /* Initially kdims is chosen so that size_drv_cur >= size_drv_min.
      *
-     * It might happen that for chosen kdims the sz_ker_cur is too small
+     * It might happen that for chosen kdims the size_ker_cur is too small
      * (less than tr::ker_prb_size_min). In that case try to split the
-     * innermost driver dimension into two, to increase sz_ker_cur. */
-    bool want_borrow_ker_from_drv = true && kdims < prb.ndims
-            && sz_ker_cur < tr::ker_prb_size_min && sz_drv_cur > sz_drv_min
-            && kdims != prb.blk_chunk_idx;
+     * innermost driver dimension into two, to increase size_ker_cur. */
+    const bool want_borrow_ker_from_drv = kdims < prb.ndims
+            && size_ker_cur < tr::ker_prb_size_min
+            && size_drv_cur > size_drv_min;
     if (want_borrow_ker_from_drv) {
-        /* sz_want_borrow is the minimal sz, so that:
-         *  o) sz_ker_cur * sz_want_borrow >= tr::ker_prb_size_min
+        /* size_want_borrow is the minimal size, so that:
+         *  o) size_ker_cur * size_want_borrow >= tr::ker_prb_size_min
          *  o) current innermost driver dimension is divisible by
-         *     sz_want_borrow (so that we can evenly split that
+         *     size_want_borrow (so that we can evenly split that
          *     dimension into two)
          *
-         *  In the worst case the minimal sz_want_borrow is equal
+         *  In the worst case the minimal size_want_borrow is equal
          *  to the innermost driver dimension itself. In that case
          *  we will sacrifice it in favor of kernel (is it fine?). */
-        size_t sz_want_borrow = utils::div_up(tr::ker_prb_size_min, sz_ker_cur);
-        for (; prb.nodes[kdims].n % sz_want_borrow; ++sz_want_borrow)
+        size_t size_want_borrow
+                = utils::div_up(tr::ker_prb_size_min, size_ker_cur);
+        for (; prb.nodes[kdims].n % size_want_borrow; ++size_want_borrow)
             ;
-        if (sz_want_borrow != prb.nodes[kdims].n)
-            prb_node_split(prb, kdims, sz_want_borrow);
+
+        if (size_want_borrow != prb.nodes[kdims].n)
+            prb_node_split(prb, kdims, size_want_borrow);
         kdims += 1;
     }
 
     /* On the other hand it might happen that for chosen kdims
-     * the sz_drv_cur is too small (less than sz_drv_min). In that case
+     * the size_drv_cur is too small (less than size_drv_min). In that case
      * try to split the outermost kernel dimension into two, to increase
-     * sz_drv_cur. */
-    bool want_borrow_drv_from_ker = true && sz_ker_cur > tr::ker_prb_size_min
-            && sz_drv_cur < sz_drv_min && kdims != prb.blk_chunk_idx;
+     * size_drv_cur. */
+    const bool want_borrow_drv_from_ker = size_ker_cur > tr::ker_prb_size_min
+            && size_drv_cur < size_drv_min;
     if (want_borrow_drv_from_ker) {
-        size_t sz_want_borrow = utils::div_up(sz_drv_min, sz_drv_cur);
-        for (; prb.nodes[kdims - 1].n % sz_want_borrow; ++sz_want_borrow)
+        size_t size_want_borrow = utils::div_up(size_drv_min, size_drv_cur);
+        for (; prb.nodes[kdims - 1].n % size_want_borrow; ++size_want_borrow)
             ;
-        if (sz_want_borrow != prb.nodes[kdims - 1].n)
+
+        if (size_want_borrow != prb.nodes[kdims - 1].n)
             prb_node_split(
-                    prb, kdims - 1, prb.nodes[kdims - 1].n / sz_want_borrow);
+                    prb, kdims - 1, prb.nodes[kdims - 1].n / size_want_borrow);
     }
 
     ndims_ker_max = kdims;
@@ -1607,6 +2484,33 @@ static void prb_thread_kernel_balance(
     }
 }
 
+status_t jit_uni_reorder_t::pd_t::init(
+        engine_t *engine, engine_t *src_engine, engine_t *dst_engine) {
+    CHECK(cpu_reorder_pd_t::init(engine, src_engine, dst_engine));
+
+    // The compensation workspace is only booked when s8s8 or asymmetric
+    // compensation has been requested for this problem.
+    if (prb_.req_s8s8_comp || prb_.req_asymmetric_comp) init_scratchpad();
+
+    return status::success;
+}
+
+void jit_uni_reorder_t::pd_t::init_scratchpad() {
+    const memory_desc_wrapper od(dst_md());
+    const auto G = with_groups_ ? od.padded_dims()[0] : 1;
+    const auto N = od.padded_dims()[with_groups_ ? 1 : 0];
+    static constexpr int cache_line_size = 16;
+    // Element count (not bytes): book<int32_t>() applies sizeof(int32_t).
+    const auto wspace_per_thr_size = utils::rnd_up(G * N, cache_line_size);
+
+    auto scratchpad = scratchpad_registry().registrar();
+    const auto compensation_reduce_size = wspace_per_thr_size * nthr_;
+
+    // Every thread gets its own scratchpad space for each N
+    scratchpad.template book<int32_t>(memory_tracking::names::key_reorder_space,
+            compensation_reduce_size);
+}
+
 status_t jit_uni_reorder_t::pd_t::create(reorder_pd_t **reorder_pd,
         engine_t *engine, const primitive_attr_t *attr, engine_t *src_engine,
         const memory_desc_t *src_md, engine_t *dst_engine,
@@ -1616,36 +2520,18 @@ status_t jit_uni_reorder_t::pd_t::create(reorder_pd_t **reorder_pd,
     status_t prb_init_status = prb_init(prb, *src_md, *dst_md, attr);
     if (prb_init_status != status::success) return prb_init_status;
 
-    DEBUG({
-        printf("init : ");
-        prb_dump(prb);
-    });
-    // Sort the prb array in increasing sizes of the output stride
-    prb_normalize(prb);
-    DEBUG({
-        printf("norm : ");
-        prb_dump(prb);
-    });
-    /* Combine the variables, which appear together on both
-             * sides of the reorder */
-    prb_simplify(prb);
-    DEBUG({
-        printf("smpl : ");
-        prb_dump(prb);
-    });
-
     prb_block_for_cache(prb);
     DEBUG({
         printf("cache: ");
         prb_dump(prb);
     });
 
-    CHECK(prb_check_blk(prb, *dst_md));
-
-    int ndims_ker_max;
+    int ndims_ker_max {};
     int nthr = dnnl_get_max_threads();
     prb_thread_kernel_balance(prb, ndims_ker_max, nthr);
 
+    if (prb.is_tail_present) prb_node_dependency(prb);
+
     tr::kernel_t::desc_t ker_desc;
     status_t ker_init_status
             = tr::kernel_t::desc_init(ker_desc, prb, ndims_ker_max);
@@ -1663,99 +2549,191 @@ status_t jit_uni_reorder_t::pd_t::create(reorder_pd_t **reorder_pd,
     auto _pd = new pd_t(
             attr, src_engine->kind(), src_md, dst_engine->kind(), dst_md);
     if (_pd == nullptr) return status::out_of_memory;
+
+    _pd->nthr_ = nthr;
+    _pd->prb_ = prb;
+    _pd->with_groups_
+            = prb.compensation_mask == tr::prb_t::comp_mask_with_groups;
     if (_pd->init(engine, src_engine, dst_engine) != status::success) {
         delete _pd;
         return status::unimplemented;
     }
-    _pd->prb_ = prb;
     _pd->ker_desc_ = ker_desc;
     _pd->init_scratchpad_md();
-    _pd->nthr_ = nthr;
+
     return safe_ptr_assign(*reorder_pd, _pd);
 }
 
-void jit_uni_reorder_t::omp_driver_0d(
-        int off, const char *in, char *out, const float *scale) const {
-    tr::call_param_t c {in, out, scale, 0};
-    (*kernel_)(&c);
+void jit_uni_reorder_t::omp_driver_0d(int off, const char *in, char *out,
+        const float *scale, int src_zp, int dst_zp,
+        int32_t *compensation_scratch) const {
+    const tr::prb_t &prb = pd()->prb_;
+    // No dims are left to the driver, so a single kernel call is issued.
+    tr::call_param_t base_params;
+    base_params.in = in;
+    base_params.out = out;
+    base_params.scale = scale;
+    base_params.src_zp = src_zp;
+    base_params.dst_zp = dst_zp;
+    base_params.compensation_scratch = compensation_scratch;
+
+    if (prb.is_tail_present) {
+        tr::tail_call_param_t tail_params;
+        tail_params.base_params = base_params;
+        // With omp_ndims == 0 the chunk array is never read, so nullptr is ok.
+        static constexpr int omp_ndims = 0;
+        fill_curr_data_chunks(prb, off, nullptr, omp_ndims, tail_params);
+        (*kernel_)(&tail_params);
+    } else {
+        (*kernel_)(&base_params);
+    }
 }
 
 void jit_uni_reorder_t::omp_driver_1d(int ithr, int nthr, int off,
-        const char *in, char *out, const float *scale) const {
-    const tr::node_t *ns = pd()->prb_.nodes + off;
+        const char *in, char *out, const float *scale, int src_zp, int dst_zp,
+        int32_t *compensation_scratch) const {
+    const tr::prb_t &prb = pd()->prb_;
+    const tr::node_t *ns = prb.nodes + off;
     for_nd(ithr, nthr, (ptrdiff_t)ns[0].n, [&](ptrdiff_t d0) {
-        auto c = tr::call_param_t();
-        c.in = in + d0 * ns[0].is * data_type_size(pd()->prb_.itype);
-        c.out = out + d0 * ns[0].os * data_type_size(pd()->prb_.otype);
-        c.scale = scale + d0 * ns[0].ss;
-        c.blk_chunks = d0;
-        (*kernel_)(&c);
+        tr::call_param_t base_params;
+        base_params.in = in + d0 * ns[0].is * data_type_size(prb.itype);
+        base_params.out = out + d0 * ns[0].os * data_type_size(prb.otype);
+        base_params.scale = scale + d0 * ns[0].ss;
+        base_params.src_zp = src_zp;
+        base_params.dst_zp = dst_zp;
+        base_params.compensation_scratch = compensation_scratch + d0 * ns[0].cs;
+        // Tail handling needs per-node chunk info on top of the base params.
+        if (prb.is_tail_present) {
+            tr::tail_call_param_t tail_params;
+            tail_params.base_params = base_params;
+            // d0 is the driver-level index of node ns[0] (prb.nodes[off]).
+            static constexpr int omp_ndims = 1;
+            const ptrdiff_t omp_data_chunks[omp_ndims] = {d0};
+            fill_curr_data_chunks(
+                    prb, off, omp_data_chunks, omp_ndims, tail_params);
+            (*kernel_)(&tail_params);
+        } else {
+            (*kernel_)(&base_params);
+        }
     });
 }
 
 void jit_uni_reorder_t::omp_driver_2d(int ithr, int nthr, int off,
-        const char *in, char *out, const float *scale) const {
-    const tr::node_t *ns = pd()->prb_.nodes + off;
-    const int blk_idx_off = pd()->prb_.blk_chunk_idx - off;
+        const char *in, char *out, const float *scale, int src_zp, int dst_zp,
+        int32_t *compensation_scratch) const {
+    const tr::prb_t &prb = pd()->prb_;
+    const tr::node_t *ns = prb.nodes + off;
     for_nd(ithr, nthr, (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n,
             [&](ptrdiff_t d1, ptrdiff_t d0) {
-                auto c = tr::call_param_t();
-                c.in = in
+                tr::call_param_t base_params;
+                base_params.in = in
                         + (d0 * ns[0].is + d1 * ns[1].is)
-                                * data_type_size(pd()->prb_.itype);
-                c.out = out
+                                * data_type_size(prb.itype);
+                base_params.out = out
                         + (d0 * ns[0].os + d1 * ns[1].os)
-                                * data_type_size(pd()->prb_.otype);
-                c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss;
-                c.blk_chunks = utils::pick(blk_idx_off, d0, d1);
-                (*kernel_)(&c);
+                                * data_type_size(prb.otype);
+                base_params.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss;
+                base_params.src_zp = src_zp;
+                base_params.dst_zp = dst_zp;
+                base_params.compensation_scratch
+                        = compensation_scratch + d0 * ns[0].cs + d1 * ns[1].cs;
+                // Tail handling also needs per-node chunk information.
+                if (prb.is_tail_present) {
+                    tr::tail_call_param_t tail_params;
+                    tail_params.base_params = base_params;
+                    // Chunk indices follow node order: ns[0], ns[1].
+                    static constexpr int omp_ndims = 2;
+                    const ptrdiff_t omp_data_chunks[omp_ndims] = {d0, d1};
+                    fill_curr_data_chunks(
+                            prb, off, omp_data_chunks, omp_ndims, tail_params);
+
+                    (*kernel_)(&tail_params);
+                } else {
+                    (*kernel_)(&base_params);
+                }
             });
 }
 
 void jit_uni_reorder_t::omp_driver_3d(int ithr, int nthr, int off,
-        const char *in, char *out, const float *scale) const {
-    const tr::node_t *ns = pd()->prb_.nodes + off;
-    const int blk_idx_off = pd()->prb_.blk_chunk_idx - off;
+        const char *in, char *out, const float *scale, int src_zp, int dst_zp,
+        int32_t *compensation_scratch) const {
+    const tr::prb_t &prb = pd()->prb_;
+    const tr::node_t *ns = prb.nodes + off;
     for_nd(ithr, nthr, (ptrdiff_t)ns[2].n, (ptrdiff_t)ns[1].n,
             (ptrdiff_t)ns[0].n, [&](ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) {
-                auto c = tr::call_param_t();
-                c.in = in
+                tr::call_param_t base_params;
+                base_params.in = in
                         + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is)
-                                * data_type_size(pd()->prb_.itype);
-                c.out = out
+                                * data_type_size(prb.itype);
+                base_params.out = out
                         + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os)
-                                * data_type_size(pd()->prb_.otype);
-                c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss;
-                c.blk_chunks = utils::pick(blk_idx_off, d0, d1, d2);
-                (*kernel_)(&c);
+                                * data_type_size(prb.otype);
+                base_params.scale
+                        = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss;
+                base_params.src_zp = src_zp;
+                base_params.dst_zp = dst_zp;
+                base_params.compensation_scratch = compensation_scratch
+                        + d0 * ns[0].cs + d1 * ns[1].cs + d2 * ns[2].cs;
+                // Tail handling also needs per-node chunk information.
+                if (prb.is_tail_present) {
+                    tr::tail_call_param_t tail_params;
+                    tail_params.base_params = base_params;
+                    // Chunk indices follow node order: ns[0], ns[1], ns[2].
+                    static constexpr int omp_ndims = 3;
+                    const ptrdiff_t omp_data_chunks[omp_ndims] = {d0, d1, d2};
+                    fill_curr_data_chunks(
+                            prb, off, omp_data_chunks, omp_ndims, tail_params);
+                    (*kernel_)(&tail_params);
+                } else {
+                    (*kernel_)(&base_params);
+                }
             });
 }
 
 void jit_uni_reorder_t::omp_driver_4d(int ithr, int nthr, int off,
-        const char *in, char *out, const float *scale) const {
-    const tr::node_t *ns = pd()->prb_.nodes + off;
-    const int blk_idx_off = pd()->prb_.blk_chunk_idx - off;
+        const char *in, char *out, const float *scale, int src_zp, int dst_zp,
+        int32_t *compensation_scratch) const {
+    const tr::prb_t &prb = pd()->prb_;
+    const tr::node_t *ns = prb.nodes + off;
     for_nd(ithr, nthr, (ptrdiff_t)ns[3].n, (ptrdiff_t)ns[2].n,
             (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n,
             [&](ptrdiff_t d3, ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) {
-                auto c = tr::call_param_t();
-                c.in = in
+                tr::call_param_t base_params;
+                base_params.in = in
                         + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is
                                   + d3 * ns[3].is)
-                                * data_type_size(pd()->prb_.itype);
-                c.out = out
+                                * data_type_size(prb.itype);
+                base_params.out = out
                         + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os
                                   + d3 * ns[3].os)
-                                * data_type_size(pd()->prb_.otype);
-                c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss
-                        + d3 * ns[3].ss;
-                c.blk_chunks = utils::pick(blk_idx_off, d0, d1, d2, d3);
-                (*kernel_)(&c);
+                                * data_type_size(prb.otype);
+                base_params.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss
+                        + d2 * ns[2].ss + d3 * ns[3].ss;
+                base_params.src_zp = src_zp;
+                base_params.dst_zp = dst_zp;
+                base_params.compensation_scratch = compensation_scratch
+                        + d0 * ns[0].cs + d1 * ns[1].cs + d2 * ns[2].cs
+                        + d3 * ns[3].cs;
+                // Tail handling also needs per-node chunk information.
+                if (prb.is_tail_present) {
+                    tr::tail_call_param_t tail_params;
+                    tail_params.base_params = base_params;
+                    // Chunk indices follow node order: ns[0] .. ns[3].
+                    static constexpr int omp_ndims = 4;
+                    const ptrdiff_t omp_data_chunks[omp_ndims]
+                            = {d0, d1, d2, d3};
+                    fill_curr_data_chunks(
+                            prb, off, omp_data_chunks, omp_ndims, tail_params);
+                    (*kernel_)(&tail_params);
+                } else {
+                    (*kernel_)(&base_params);
+                }
             });
 }
 
-void jit_uni_reorder_t::omp_driver(
-        const char *in, char *out, const float *scale) const {
+void jit_uni_reorder_t::omp_driver(const char *in, char *out,
+        const float *scale, int src_zp, int dst_zp,
+        const memory_tracking::grantor_t &scratchpad) const {
     in += pd()->prb_.ioff * data_type_size(pd()->prb_.itype);
     out += pd()->prb_.ooff * data_type_size(pd()->prb_.otype);
 
@@ -1770,29 +2748,153 @@ void jit_uni_reorder_t::omp_driver(
 
     int ndims = pd()->prb_.ndims;
     int ndims_ker = pd()->ker_desc_.prb.ndims;
+    const bool req_s8s8_comp = pd()->prb_.req_s8s8_comp;
+    const bool req_asymmetric_comp = pd()->prb_.req_asymmetric_comp;
+    const bool req_compensation = req_s8s8_comp || req_asymmetric_comp;
     assert(ndims - ndims_ker <= ndims_driver_max);
 
+    int32_t *compensation_reduce_scratch = scratchpad.template get<int32_t>(
+            memory_tracking::names::key_reorder_space);
+
+    const memory_desc_wrapper od(pd()->dst_md());
+    const auto G = pd()->with_groups_ ? od.padded_dims()[0] : 1;
+    const auto N = od.padded_dims()[pd()->with_groups_ ? 1 : 0];
+    static constexpr int cache_line_size = 16;
+    const auto wspace_per_thr_size = utils::rnd_up(G * N, cache_line_size);
+    const auto wspace_per_thr_bytes = wspace_per_thr_size * sizeof(int32_t);
+
     if (ndims - ndims_ker == 0) {
-        omp_driver_0d(ndims_ker, in, out, scale);
+        if (req_compensation)
+            std::memset(compensation_reduce_scratch, 0, wspace_per_thr_bytes);
+
+        omp_driver_0d(ndims_ker, in, out, scale, src_zp, dst_zp,
+                compensation_reduce_scratch);
     } else {
         parallel(pd()->nthr_, [&](const int ithr, const int nthr) {
+            int32_t *compensation_scratch = nullptr;
+            if (req_compensation) {
+                compensation_scratch = &compensation_reduce_scratch[ithr
+                        * wspace_per_thr_size];
+                std::memset(compensation_scratch, 0, wspace_per_thr_bytes);
+            }
+
             switch (ndims - ndims_ker) {
                 case 1:
-                    omp_driver_1d(ithr, nthr, ndims_ker, in, out, scale);
+                    omp_driver_1d(ithr, nthr, ndims_ker, in, out, scale, src_zp,
+                            dst_zp, compensation_scratch);
                     break;
                 case 2:
-                    omp_driver_2d(ithr, nthr, ndims_ker, in, out, scale);
+                    omp_driver_2d(ithr, nthr, ndims_ker, in, out, scale, src_zp,
+                            dst_zp, compensation_scratch);
                     break;
                 case 3:
-                    omp_driver_3d(ithr, nthr, ndims_ker, in, out, scale);
+                    omp_driver_3d(ithr, nthr, ndims_ker, in, out, scale, src_zp,
+                            dst_zp, compensation_scratch);
                     break;
                 case 4:
-                    omp_driver_4d(ithr, nthr, ndims_ker, in, out, scale);
+                    omp_driver_4d(ithr, nthr, ndims_ker, in, out, scale, src_zp,
+                            dst_zp, compensation_scratch);
                     break;
                 default: assert(!"unimplemented");
             }
         });
     }
+
+    // Reduction of intermediate compensation results to the final output
+    if (req_compensation) {
+        const int nthr = ndims - ndims_ker == 0 ? 1 : pd()->nthr_;
+        reduce_compensation(
+                out, compensation_reduce_scratch, nthr, wspace_per_thr_size);
+    }
+}
+
+void jit_uni_reorder_t::reduce_compensation(char *out,
+        const int32_t *compensation_reduce_scratch, const int nthr,
+        const dim_t wspace_per_thr_size) const {
+    // Reduces the per-thread partial sums held in the scratchpad (nthr slices
+    // of wspace_per_thr_size elements each) into the compensation area of out.
+    const memory_desc_wrapper od(pd()->dst_md());
+    const size_t offset = od.size() - od.additional_buffer_size();
+    static constexpr auto comp_dt_size = sizeof(int32_t);
+    static constexpr int32_t comp_s8s8_shift = 128;
+
+    // Note: We do not need to explicitly zero-out compensation buffer, as the
+    // per_thread buffers are already zeroed out in the padded area.
+    const auto G = pd()->with_groups_ ? od.padded_dims()[0] : 1;
+    const auto N = od.padded_dims()[pd()->with_groups_ ? 1 : 0];
+    const auto GN = G * N;
+    const bool req_s8s8_comp = pd()->prb_.req_s8s8_comp;
+    const bool req_asymmetric_comp = pd()->prb_.req_asymmetric_comp;
+    // The asymmetric part follows the s8s8 part when both are requested.
+    const size_t zp_offset = offset + (req_s8s8_comp ? GN * comp_dt_size : 0);
+
+    parallel_nd(GN, [&](dim_t idx) {
+        int32_t acc = 0; // negated sum of the per-thread partial results
+        for (int ithr = 0; ithr < nthr; ithr++) {
+            acc -= compensation_reduce_scratch[ithr * wspace_per_thr_size
+                    + idx];
+        }
+        if (req_s8s8_comp) {
+            int32_t *out_comp = reinterpret_cast<int32_t *>(&out[offset]);
+            out_comp[idx] = comp_s8s8_shift * acc;
+        }
+        if (req_asymmetric_comp) {
+            int32_t *out_asym_comp
+                    = reinterpret_cast<int32_t *>(&out[zp_offset]);
+            out_asym_comp[idx] = acc;
+        }
+    });
+}
+
+void jit_uni_reorder_t::fill_curr_data_chunks(const tr::prb_t &prb,
+        const int off, const ptrdiff_t *omp_data_chunks, const int omp_ndims,
+        tr::tail_call_param_t &c) const {
+    // Chunks are numbered backwards, i.e.:
+    // [0] -> [node_size]
+    // [1] -> [node_size - 1]
+    // ...
+    // [node_size - 1] -> [1]
+
+    // This is done because, in the jit kernel, it is easier to decrement a
+    // counter and compare it with zero than to increment it and compare it
+    // with node_size.
+
+    static constexpr int64_t empty_chunk_info = -1;
+    static constexpr int64_t last_chunk = 1;
+    // Descending order, so parents (expected at higher ids) are filled first.
+    for (int curr_node_id = prb.ndims - 1; curr_node_id >= 0; curr_node_id--) {
+        const int parent_node_id = prb.nodes[curr_node_id].parent_node_id;
+        const bool is_drv_processing_this_node
+                = curr_node_id >= off && curr_node_id <= off + omp_ndims - 1;
+        const bool is_tail_processing
+                = prb.is_tail_in_one_of_child_nodes(curr_node_id)
+                || prb.nodes[curr_node_id].tail_size > 0;
+
+        if (is_drv_processing_this_node && is_tail_processing) {
+            const int inner_idx = curr_node_id - off;
+            assert(inner_idx < omp_ndims);
+            const int64_t node_size = prb.nodes[curr_node_id].tail_size > 0
+                    ? prb.nodes[curr_node_id].tail_size
+                    : prb.nodes[curr_node_id].n;
+            const int64_t data_chunk = node_size - omp_data_chunks[inner_idx];
+
+            if (!prb.nodes[curr_node_id].is_parent_empty()) {
+                const bool is_parent_chunk_last
+                        = c.curr_data_chunks[parent_node_id] == last_chunk;
+                c.curr_data_chunks[curr_node_id]
+                        = is_parent_chunk_last ? data_chunk : empty_chunk_info;
+                c.zeroing_data = static_cast<int64_t>(
+                        is_parent_chunk_last && data_chunk <= 0);
+            } else {
+                c.curr_data_chunks[curr_node_id] = data_chunk;
+                c.zeroing_data = static_cast<int64_t>(data_chunk <= 0);
+            }
+            c.skip_kernel_execution = static_cast<int64_t>(c.zeroing_data
+                    && !prb.nodes[curr_node_id].is_zero_pad_needed);
+            if (c.zeroing_data || c.skip_kernel_execution) break;
+        } else
+            c.curr_data_chunks[curr_node_id] = empty_chunk_info;
+    }
 }
 
 status_t jit_uni_reorder_t::init(engine_t *engine) {
@@ -1801,13 +2903,98 @@ status_t jit_uni_reorder_t::init(engine_t *engine) {
 }
 
 status_t jit_uni_reorder_t::execute(const exec_ctx_t &ctx) const {
-    status_t status = status::success;
     auto in = CTX_IN_MEM(const char *, DNNL_ARG_FROM);
-    auto out = CTX_OUT_CLEAN_MEM(char *, DNNL_ARG_TO, status);
-    CHECK(status);
+    auto out = CTX_OUT_MEM(char *, DNNL_ARG_TO);
     DEFINE_SCALES_BUFFER(scales);
+    DEFINE_ZERO_POINT_VALUE(src_zp, DNNL_ARG_FROM);
+    DEFINE_ZERO_POINT_VALUE(dst_zp, DNNL_ARG_TO);
+    const auto &scratchpad = ctx.get_scratchpad_grantor();
+    // The grantor gives omp_driver access to the compensation workspace.
+    omp_driver(in, out, scales, src_zp, dst_zp, scratchpad);
+
+    return status::success;
+}
+
+status_t jit_blk_reorder_t::pd_t::create(reorder_pd_t **reorder_pd,
+        engine_t *engine, const primitive_attr_t *attr, engine_t *src_engine,
+        const memory_desc_t *src_md, engine_t *dst_engine,
+        const memory_desc_t *dst_md) {
+    auto prb = tr::prb_t();
+    // Describe the reorder as a problem (prb) built from both descriptors.
+    status_t prb_init_status = prb_init(prb, *src_md, *dst_md, attr);
+    if (prb_init_status != status::success) return prb_init_status;
+    // only uni_reorder supports tail processing now
+    // TODO: Add tail processing support in blk_reorder
+    if (prb.is_tail_present) return status::unimplemented;
+    // Bring a node of size 8 or 16 to the front if only the second has it.
+    prb_tile_normalize(prb);
+    DEBUG({
+        printf("tile : ");
+        prb_dump(prb);
+    });
+    // Bail out unless the single-block kernel reports the problem as usable.
+    if (!tr::jit_single_blk_kernel_t::applicable(prb)) {
+        return status::unimplemented;
+    }
 
-    omp_driver(in, out, scales);
+    auto _pd = new pd_t(
+            attr, src_engine->kind(), src_md, dst_engine->kind(), dst_md);
+    if (_pd == nullptr) return status::out_of_memory;
+    _pd->prb_ = prb;
+    if (_pd->init(engine, src_engine, dst_engine) != status::success) {
+        delete _pd;
+        return status::unimplemented;
+    }
+    _pd->init_scratchpad_md();
+
+    return safe_ptr_assign(*reorder_pd, _pd);
+}
+
+void jit_blk_reorder_t::pd_t::prb_tile_normalize(tr::prb_t &p) {
+    // Swap the first two nodes when only the second has a size of 8 or 16.
+    const bool n0_is_tile = utils::one_of(p.nodes[0].n, 8ul, 16ul);
+    const bool n1_is_tile = utils::one_of(p.nodes[1].n, 8ul, 16ul);
+    if (!n0_is_tile && n1_is_tile) nstl::swap(p.nodes[0], p.nodes[1]);
+}
+
+jit_blk_reorder_t::jit_blk_reorder_t(const pd_t *apd) : primitive_t(apd) {}
+jit_blk_reorder_t::~jit_blk_reorder_t() = default;
+
+status_t jit_blk_reorder_t::init(engine_t *engine) {
+    kernel_ = utils::make_unique<tr::jit_single_blk_kernel_t>(pd()->prb_);
+    return kernel_->create_kernel();
+}
+
+status_t jit_blk_reorder_t::execute(const exec_ctx_t &ctx) const {
+    const auto in = CTX_IN_MEM(const char *, DNNL_ARG_FROM);
+    auto out = CTX_OUT_MEM(char *, DNNL_ARG_TO);
+    DEFINE_ZERO_POINT_VALUE(src_zp, DNNL_ARG_FROM);
+    DEFINE_ZERO_POINT_VALUE(dst_zp, DNNL_ARG_TO);
+
+    // The kernel handles 2-dimensional tiles; a tail tile is possible.
+    auto &prb = this->pd()->prb_;
+    ptrdiff_t BH = 1; // product of the sizes of all nodes past the first two
+    for (int i = 2; i < prb.ndims; ++i) {
+        BH *= prb.nodes[i].n;
+    }
+    // FL counts block_sz-sized tiles along node 1; the last may be partial.
+    auto block_sz = prb.n(0);
+    auto n1 = prb.n(1);
+    auto i1 = prb.is(1);
+    auto o1 = prb.os(1);
+    auto FL = (n1 + block_sz - 1) / block_sz;
+    auto bh_stride = BH == 1 ? 0 : prb.is(2);
+
+    auto itype_sz_ = data_type_size(pd()->prb_.itype);
+    auto otype_sz_ = data_type_size(pd()->prb_.otype);
+    // NOTE(review): bh_b offsets both in and out; assumes is(2) == os(2).
+    parallel_nd(BH, FL, [&](dim_t bh, dim_t fl) {
+        auto fl_b = fl * block_sz;
+        auto bh_b = bh_stride * bh;
+        auto *i = in + (bh_b + fl_b * i1) * itype_sz_;
+        auto *o = out + (bh_b + fl_b * o1) * otype_sz_;
+        (*kernel_)(i, o, n1 - fl_b < block_sz, src_zp, dst_zp);
+    });
 
     return status::success;
 }
diff --git a/src/cpu/aarch64/jit_uni_reorder.hpp b/src/cpu/aarch64/jit_uni_reorder.hpp
index 2fb6f0f89f3..bf400430ba5 100644
--- a/src/cpu/aarch64/jit_uni_reorder.hpp
+++ b/src/cpu/aarch64/jit_uni_reorder.hpp
@@ -1,6 +1,6 @@
 /*******************************************************************************
-* Copyright 2018-2020 Intel Corporation
-* Copyright 2020 FUJITSU LIMITED
+* Copyright 2018-2022 Intel Corporation
+* Copyright 2020-2022 FUJITSU LIMITED
 * Copyright 2022 Arm Ltd. and affiliates
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
@@ -36,15 +36,76 @@ namespace tr {
 constexpr int max_ndims = DNNL_MAX_NDIMS;
 
 struct node_t {
-    size_t n;
-    ptrdiff_t is; // input stride
-    ptrdiff_t os; // output stride
-    ptrdiff_t ss; // scale stride
+    static constexpr int64_t empty_field = -1; // sentinel for an unset dim_id / parent_node_id
+
+    size_t n = 0; // node size (taken from the layout dims in prb_init)
+    size_t tail_size = 0; // 0 means the node has no tail
+    int dim_id = empty_field; // logical dimension id (old.id[...] in prb_init)
+    int parent_node_id = empty_field; // index of a later node with the same dim_id (see prb_node_dependency)
+    bool is_zero_pad_needed = false; // set in prb_init for blocked output nodes that have a tail
+    ptrdiff_t is = 0; // input stride
+    ptrdiff_t os = 0; // output stride
+    ptrdiff_t ss = 0; // scale stride
+    ptrdiff_t cs = 0; // compensation stride
+
+    bool is_dim_id_empty() const { return dim_id == empty_field; }
+    bool is_parent_empty() const { return parent_node_id == empty_field; }
 };
 
 enum class scale_type_t { NONE, COMMON, MANY };
 
 struct prb_t {
+    /* The compensation mask value indicates how big an additional buffer should be.
+     * Possible values for reorder:
+     *     1) standard compensation = 1 = 0b01
+     *     2) asymmetric compensation = 2 = 0b10
+     *     3) compensation if tensor contains group = 3 = 0b11 */
+    static constexpr int invalid_comp_mask = 0;
+    static constexpr int standard_comp_mask = 0b1;
+    static constexpr int asymmetric_comp_mask = 0b10;
+    static constexpr int comp_mask_with_groups
+            = standard_comp_mask + asymmetric_comp_mask;
+
+    bool is_tail_in_one_of_child_nodes(int parent_node_id) const {
+        // Descend from the given node towards node 0, following the chain of
+        // nodes whose parent_node_id links back to the current ancestor.
+        int ancestor = parent_node_id;
+        for (int idx = parent_node_id; idx >= 0; --idx) {
+            if (nodes[idx].parent_node_id != ancestor) continue;
+            // A child with a tail ends the search; otherwise keep descending.
+            if (nodes[idx].tail_size != 0) return true;
+            ancestor = idx;
+        }
+        return false;
+    }
+
+    int tail(int d) const { // tail size of node d, narrowed from size_t to int
+        assert(d < ndims); // only the upper bound is checked
+        return static_cast<int>(nodes[d].tail_size);
+    }
+
+    int n(int d) const { // size of node d, narrowed from size_t to int
+        assert(d < ndims); // only the upper bound is checked
+        return static_cast<int>(nodes[d].n);
+    }
+    int is(int d) const { // input stride of node d, narrowed from ptrdiff_t to int
+        assert(d < ndims); // only the upper bound is checked
+        return static_cast<int>(nodes[d].is);
+    }
+    int os(int d) const { // output stride of node d, narrowed from ptrdiff_t to int
+        assert(d < ndims); // only the upper bound is checked
+        return static_cast<int>(nodes[d].os);
+    }
+    int ss(int d) const { // scale stride of node d, narrowed from ptrdiff_t to int
+        assert(d < ndims); // only the upper bound is checked
+        return static_cast<int>(nodes[d].ss);
+    }
+
+    int cs(int d) const { // compensation stride of node d, narrowed from ptrdiff_t to int
+        assert(d < ndims); // only the upper bound is checked
+        return static_cast<int>(nodes[d].cs);
+    }
+
     data_type_t itype;
     data_type_t otype;
     int ndims;
@@ -54,21 +115,24 @@ struct prb_t {
     scale_type_t scale_type;
     float beta;
     int full_ndims;
-    int ip_tail;
-    int op_tail;
-    int iblock;
-    int oblock;
-    int blk_chunk_idx;
+    bool is_tail_present = false;
+    float scale_adjust = 1.f;
+    int compensation_mask = invalid_comp_mask;
+    bool req_s8s8_comp = false;
+    bool req_asymmetric_comp = false;
+    bool req_src_zp = false;
+    bool req_dst_zp = false;
 };
 
 status_t prb_init(prb_t &prb, const memory_desc_t &imd,
         const memory_desc_t &omd, const primitive_attr_t *attr);
 
-status_t prb_check_blk(prb_t &prb, const memory_desc_t &imd);
-
 /** sorts the problem nodes so that output strides come in ascending order */
 void prb_normalize(prb_t &p);
 
+/** fill parent node info for blocked nodes */
+void prb_node_dependency(prb_t &p);
+
 /** folds nodes together if possible */
 void prb_simplify(prb_t &p);
 
@@ -88,10 +152,24 @@ void prb_node_move(prb_t &p, int d0, int d1);
 void prb_dump(const prb_t &p);
 
 struct call_param_t {
-    const void *in;
-    void *out;
-    const float *scale;
-    size_t blk_chunks;
+    const void *in = nullptr; // input data pointer
+    void *out = nullptr; // output data pointer
+    const float *scale = nullptr; // presumably the output scales; null when unused — confirm
+    int32_t src_zp = 0; // presumably the source zero point — confirm against the drivers
+    int32_t dst_zp = 0; // presumably the destination zero point — confirm against the drivers
+    int32_t *compensation_scratch = nullptr; // null unless a compensation buffer is supplied
+};
+
+// Separate parameter block for the tail-processing kernel entry point.
+// Keeping the tail-only fields out of call_param_t matters because passing
+// them for non-tail cases reduces kernel performance: too much data would
+// have to be transferred to the kernel.
+// base_params carries the regular (non-tail) arguments.
+struct tail_call_param_t {
+    call_param_t base_params;
+    int64_t curr_data_chunks[DNNL_MAX_NDIMS] = {-1}; // NOTE(review): only element 0 becomes -1, the rest are zero-initialized — confirm intended
+    int64_t zeroing_data = static_cast<int64_t>(false); // boolean flag stored as int64_t
+    int64_t skip_kernel_execution = static_cast<int64_t>(false); // boolean flag stored as int64_t
 };
 
 struct kernel_t {
@@ -100,8 +178,12 @@ struct kernel_t {
         prb_t prb;
     };
 
-    kernel_t(const desc_t &desc) : desc_(desc) {}
+    kernel_t(const desc_t &desc)
+        : desc_(desc)
+        , compensation_needed_(
+                  desc.prb.req_s8s8_comp || desc.prb.req_asymmetric_comp) {}
     virtual void operator()(const call_param_t *c) const = 0;
+    virtual void operator()(const tail_call_param_t *c) const = 0;
     virtual status_t create_kernel() = 0;
     virtual ~kernel_t() {}
 
@@ -119,10 +201,13 @@ struct kernel_t {
 protected:
     const desc_t desc_;
     const prb_t &prb_ = desc_.prb;
+    bool compensation_needed_ = false;
 };
 
 /* TODO: add trans_t class */
 
+struct jit_single_blk_kernel_t;
+
 } // namespace tr
 
 struct jit_uni_reorder_t : public primitive_t {
@@ -135,8 +220,13 @@ struct jit_uni_reorder_t : public primitive_t {
         tr::prb_t prb_;
         tr::kernel_t::desc_t ker_desc_;
         int nthr_;
+        bool with_groups_ = false;
+
+        status_t init(
+                engine_t *engine, engine_t *src_engine, engine_t *dst_engine);
 
     private:
+        void init_scratchpad();
         static status_t create(reorder_pd_t **reorder_pd, engine_t *engine,
                 const primitive_attr_t *attr, engine_t *src_engine,
                 const memory_desc_t *src_md, engine_t *dst_engine,
@@ -151,23 +241,66 @@ struct jit_uni_reorder_t : public primitive_t {
     enum { ndims_driver_max = 4 };
 
 private:
-    void omp_driver_0d(
-            int off, const char *in, char *out, const float *scale) const;
+    void omp_driver_0d(int off, const char *in, char *out, const float *scale,
+            int src_zp, int dst_zp, int32_t *compensation_scratch) const;
     void omp_driver_1d(int ithr, int nthr, int off, const char *in, char *out,
-            const float *scale) const;
+            const float *scale, int src_zp, int dst_zp,
+            int32_t *compensation_scratch) const;
     void omp_driver_2d(int ithr, int nthr, int off, const char *in, char *out,
-            const float *scale) const;
+            const float *scale, int src_zp, int dst_zp,
+            int32_t *compensation_scratch) const;
     void omp_driver_3d(int ithr, int nthr, int off, const char *in, char *out,
-            const float *scale) const;
+            const float *scale, int src_zp, int dst_zp,
+            int32_t *compensation_scratch) const;
     void omp_driver_4d(int ithr, int nthr, int off, const char *in, char *out,
-            const float *scale) const;
+            const float *scale, int src_zp, int dst_zp,
+            int32_t *compensation_scratch) const;
+
+    void omp_driver(const char *in, char *out, const float *scale, int src_zp,
+            int dst_zp, const memory_tracking::grantor_t &scratchpad) const;
 
-    void omp_driver(const char *in, char *out, const float *scale) const;
+    void fill_curr_data_chunks(const tr::prb_t &prb, const int off,
+            const ptrdiff_t *omp_data_chunks, const int omp_ndims,
+            tr::tail_call_param_t &c) const;
+
+    void reduce_compensation(char *out,
+            const int32_t *compensation_reduce_scratch, const int nthr,
+            const dim_t wspace_per_thr_size) const;
 
     const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
     std::unique_ptr<tr::kernel_t> kernel_;
 };
 
+struct jit_blk_reorder_t : public primitive_t {
+    using primitive_t::primitive_t;
+    struct pd_t : public cpu_reorder_pd_t {
+        using cpu_reorder_pd_t::cpu_reorder_pd_t;
+        DECLARE_COMMON_PD_T("jit:blk", jit_blk_reorder_t);
+
+        tr::prb_t prb_; // reorder problem description read by init() and execute()
+
+    private:
+        static status_t create(reorder_pd_t **reorder_pd, engine_t *engine,
+                const primitive_attr_t *attr, engine_t *src_engine,
+                const memory_desc_t *src_md, engine_t *dst_engine,
+                const memory_desc_t *dst_md);
+
+        // Swap nodes 0 and 1 when only node 1 has a tile size of 8 or 16
+        static void prb_tile_normalize(tr::prb_t &p);
+        friend dnnl::impl::impl_list_item_t;
+    };
+
+    status_t init(engine_t *engine) override; // builds kernel_ from pd()->prb_
+    status_t execute(const exec_ctx_t &ctx) const override;
+
+    jit_blk_reorder_t(const pd_t *apd);
+    ~jit_blk_reorder_t();
+
+private:
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
+    std::unique_ptr<tr::jit_single_blk_kernel_t> kernel_; // created in init()
+};
+
 } // namespace aarch64
 } // namespace cpu
 } // namespace impl
diff --git a/src/cpu/aarch64/jit_uni_reorder_utils.cpp b/src/cpu/aarch64/jit_uni_reorder_utils.cpp
index 7123811f827..28f36a7e2e7 100644
--- a/src/cpu/aarch64/jit_uni_reorder_utils.cpp
+++ b/src/cpu/aarch64/jit_uni_reorder_utils.cpp
@@ -1,6 +1,6 @@
 /*******************************************************************************
-* Copyright 2018-2021 Intel Corporation
-* Copyright 2020 FUJITSU LIMITED
+* Copyright 2018-2022 Intel Corporation
+* Copyright 2020-2022 FUJITSU LIMITED
 * Copyright 2022 Arm Ltd. and affiliates
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
@@ -25,10 +25,21 @@
 #include "common/nstl.hpp"
 #include "common/type_helpers.hpp"
 #include "common/utils.hpp"
-#include "dnnl_debug.h"
+#include "oneapi/dnnl/dnnl_debug.h"
 
 #include "cpu/aarch64/jit_uni_reorder.hpp"
 
+// #define TR_DEBUG // uncomment to compile the DEBUG({...}) blocks in this file
+#if defined(TR_DEBUG)
+#define DEBUg(...) \
+    do { \
+        __VA_ARGS__ \
+    } while (0)
+#else
+#define DEBUg(...)
+#endif // TR_DEBUG
+#define DEBUG(...) DEBUg(__VA_ARGS__)
+
 using namespace dnnl::impl::types;
 using namespace dnnl::impl::status;
 
@@ -41,87 +52,45 @@ namespace tr {
 
 /** ad-hoc structure to describe blocked memory layout */
 struct layout_desc_t {
+    layout_desc_t()
+        : dt(dnnl_data_type_undef)
+        , ndims(0)
+        , id {-1} // NOTE(review): presumably only id[0] becomes -1, the other entries are zero-initialized — confirm intended
+        , dims {0}
+        , tails {0}
+        , is_blk {false}
+        , strides {0} {}
     data_type_t dt;
     int ndims;
     dims_t id;
     dims_t dims;
+    dims_t tails; // per-entry tail size, 0 when the entry has no tail
+    bool is_blk[DNNL_MAX_NDIMS]; // true for entries created from inner blocks
     strides_t strides;
 };
 
-static status_t compute_blk_and_tail(
-        const memory_desc_t &md_, const int idx, int &blk, int &tail) {
-    const auto md = memory_desc_wrapper(md_);
-    const auto &bd = md.blocking_desc();
-    if (tail == 0) return status::success;
-
-    const std::set<dim_t> unique_inner_idxs(
-            bd.inner_idxs, bd.inner_idxs + bd.inner_nblks);
-    std::set<dim_t> dims_with_multiple_blks;
-    for (dim_t dim : unique_inner_idxs) {
-        if (std::count(bd.inner_idxs, bd.inner_idxs + bd.inner_nblks, dim) > 1)
-            dims_with_multiple_blks.insert(dim);
-    }
-
-    // Dims that have a tail and have multiple blocks are not supported by the jit kernel yet.
-    // For example:
-    // src_tag = abcd
-    // dst_tag = ABcd16b16a4b
-    // 16x15x3x3
-    // In this case, 'b' dim has two blocks and has a tail. It is not a supported case.
-    if (dims_with_multiple_blks.find(idx) != dims_with_multiple_blks.end())
-        return status::unimplemented;
-
-    // Only supports inconsistent padding in single and double blocks
-    // and the total block size <= 256
-    for (int iblk = bd.inner_nblks - 1; iblk > 0; --iblk) {
-        if (bd.inner_idxs[iblk] == idx) break;
-        blk *= bd.inner_blks[iblk];
-        tail *= bd.inner_blks[iblk];
-    }
-    if (unique_inner_idxs.size() > 2 || blk > 256) return status::unimplemented;
-
-    return status::success;
-}
-
-static status_t compute_chunk_idx(const prb_t &p, const memory_desc_t &imd_,
-        const memory_desc_t &omd_, const int blk_idx, int &chunk_idx) {
-    const auto imd = memory_desc_wrapper(imd_);
-    const auto omd = memory_desc_wrapper(omd_);
-    const auto &ibd = imd.blocking_desc();
-    const auto &obd = omd.blocking_desc();
-    if (p.ip_tail == 0 && p.op_tail == 0) return status::success;
-
-    const ptrdiff_t is
-            = ibd.strides[blk_idx] * obd.inner_blks[obd.inner_idxs[blk_idx]];
-    const ptrdiff_t os = obd.strides[blk_idx];
-
-    for (int i = blk_idx; i < omd.ndims(); ++i) {
-        if (p.nodes[i].os == os && p.nodes[i].is == is) {
-            chunk_idx = i;
-            return status::success;
-        }
-    }
-
-    return status::invalid_arguments;
-}
-
 status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_,
-        layout_desc_t &ld, const dims_t &blocks, const dims_t &ext_padding) {
+        layout_desc_t &ld, const dims_t &blocks, const dims_t &external_padding,
+        const dims_t &tails) {
+    static constexpr bool it_is_blk = true;
+
     const auto md = memory_desc_wrapper(md_);
 
-    bool ok = true && md.is_blocking_desc() && md.extra().flags == 0;
-    if (!ok) return invalid_arguments;
+    if (!md.is_blocking_desc()) return invalid_arguments;
 
     const auto &bd = md.blocking_desc();
 
     ld.ndims = 0;
     ld.dt = md.data_type();
 
-    auto P = [&ld](int id, int dim, ptrdiff_t stride) {
+    auto add_dim = [&ld](int id, dim_t dim, dim_t tail, bool is_blk,
+                           ptrdiff_t stride) {
         assert((size_t)ld.ndims < sizeof(ld.dims) / sizeof(ld.dims[0]));
         ld.id[ld.ndims] = id;
         ld.dims[ld.ndims] = dim;
         ld.strides[ld.ndims] = stride;
+        ld.tails[ld.ndims] = tail;
+        ld.is_blk[ld.ndims] = is_blk;
         ++ld.ndims;
     };
 
@@ -129,12 +98,27 @@ status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_,
         const int ld_ndims_start = ld.ndims;
         if (blocks[d] != 1) {
             stride_t stride = 1;
+            int tail = tails[d];
             for (int iblk = bd.inner_nblks - 1; iblk >= 0; --iblk) {
-                if (bd.inner_idxs[iblk] == d) P(d, bd.inner_blks[iblk], stride);
+                if (bd.inner_idxs[iblk] == d) {
+                    const dim_t inner_tail = tail % bd.inner_blks[iblk];
+                    add_dim(d, bd.inner_blks[iblk], inner_tail, it_is_blk,
+                            stride);
+                    tail = utils::div_up(tail, bd.inner_blks[iblk]);
+                }
                 stride *= bd.inner_blks[iblk];
             }
         }
-        P(d, (md.padded_dims()[d] + ext_padding[d]) / blocks[d], bd.strides[d]);
+
+        const dim_t dim_with_external_padding
+                = (md.padded_dims()[d] + external_padding[d]) / blocks[d];
+        const dim_t padded_dim = md.padded_dims()[d] / blocks[d];
+        const dim_t tail = dim_with_external_padding != padded_dim
+                ? dim_with_external_padding
+                        - (dim_with_external_padding - padded_dim)
+                : 0;
+
+        add_dim(d, dim_with_external_padding, tail, !it_is_blk, bd.strides[d]);
 
         // TODO: NOW: revisit, do we need a reverse?
         // TODO: NOW: consider using strides instead of block sizes in md
@@ -144,12 +128,70 @@ status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_,
             const int idx1 = ld.ndims - 1 - ld_d;
             nstl::swap(ld.dims[idx0], ld.dims[idx1]);
             nstl::swap(ld.strides[idx0], ld.strides[idx1]);
+            nstl::swap(ld.tails[idx0], ld.tails[idx1]);
+            nstl::swap(ld.is_blk[idx0], ld.is_blk[idx1]);
         }
     }
 
     return success;
 }
 
+static bool is_with_groups(const memory_desc_t &dst_md) {
+    using namespace memory_extra_flags;
+    const auto md_wrap = memory_desc_wrapper(dst_md);
+    // A compensation kind counts only if its mask has bit 1 (group bit) set.
+    const auto has_flag_with_group_bit = [&](int flag, int mask) -> bool {
+        return (md_wrap.extra().flags & flag) && (mask & (1 << 1));
+    };
+    if (has_flag_with_group_bit(
+                compensation_conv_s8s8, md_wrap.extra().compensation_mask))
+        return true;
+    return has_flag_with_group_bit(compensation_conv_asymmetric_src,
+            md_wrap.extra().asymm_compensation_mask);
+}
+
+static inline int get_next_parent_node(node_t *nodes, int ndims, int cur_node) {
+    const int wanted_id = nodes[cur_node].dim_id;
+    int next = cur_node + 1; // first later node with the same dim_id, if any
+    while (next < ndims && nodes[next].dim_id != wanted_id)
+        ++next;
+    return next < ndims ? next : -1; // -1: no later node shares the dim_id
+}
+
+static void prb_set_compensation_strides(prb_t &p) {
+    // Fills node_t::cs for every node whose dim_id is selected by the mask.
+    auto require_n_stride = [&](int cur_node) -> bool {
+        const int parent = get_next_parent_node(p.nodes, p.ndims, cur_node);
+        if (parent < 0) return false; // no later node shares this dim_id
+
+        const size_t p_n = p.nodes[parent].n;
+
+        // if 'parent_node.n' is larger than 1, then the stride contribution
+        // of cur_node is 'cur_node.n' (not its tail size)
+        return p_n > size_t(1);
+    };
+
+    const auto compensation_needed = p.req_s8s8_comp || p.req_asymmetric_comp;
+    if (!compensation_needed) return; // every cs keeps its default of 0
+    int mask = p.compensation_mask;
+    ptrdiff_t cs = 1; // running compensation stride
+    for (int d = 0; d < p.ndims; ++d) {
+        if (mask & (1 << p.nodes[d].dim_id)) {
+            // NOTE(review): the shift assumes dim_id >= 0 (not empty_field)
+            // clamp: correct cases when 'cs' exceeds output stride
+            if (cs > p.nodes[d].os) cs = p.nodes[d].os;
+
+            p.nodes[d].cs = cs;
+            const bool n_stride = require_n_stride(d);
+            if (p.nodes[d].tail_size > 0 && (!p.nodes[d].is_zero_pad_needed)
+                    && (!n_stride))
+                cs *= p.nodes[d].tail_size; // advance by the tail size only
+            else
+                cs *= p.nodes[d].n; // advance by the full node size
+        }
+    }
+}
+
 status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd,
         const primitive_attr_t *attr) {
     auto im_d = memory_desc_wrapper(imd);
@@ -157,8 +199,7 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd,
 
     auto check_post_ops = [](const primitive_attr_t *attr) {
         const auto &po = attr->post_ops_;
-        return po.len() == 0
-                || (po.len() == 1 && po.contain(primitive_kind::sum, 0));
+        return po.len() == 0 || (po.len() == 1 && po.entry_[0].is_sum(false));
     };
 
     bool ok = im_d.is_blocking_desc() && om_d.is_blocking_desc()
@@ -166,81 +207,129 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd,
             && !om_d.has_runtime_dims_or_strides() && !om_d.has_zero_dim()
             && attr->has_default_values(
                     primitive_attr_t::skip_mask_t::oscale_runtime
+                    | primitive_attr_t::skip_mask_t::zero_points_runtime
                     | primitive_attr_t::skip_mask_t::post_ops)
             && check_post_ops(attr);
     if (!ok) return unimplemented;
 
-    dims_t iblocks, oblocks, ip_padding, op_padding;
+    bool is_tail_present = false;
+    dims_t iblocks, oblocks, i_tails, o_tails, i_paddings, o_paddings;
     im_d.compute_blocks(iblocks);
     om_d.compute_blocks(oblocks);
-    utils::array_set(ip_padding, 0, im_d.ndims());
-    utils::array_set(op_padding, 0, om_d.ndims());
-
-    /* padding_dim consistency check
-     * only supports inconsitent padding for src
-     * TODO: Add inconsistent padding support for dst */
-    int ip_tail = 0;
-    int op_tail = 0;
-    int iblk_w_tail = 1;
-    int oblk_w_tail = 1;
-    int blk_idx = 0;
+
+    for (int d = 0; d < om_d.ndims(); ++d) {
+        const auto dim = om_d.dims()[d];
+        const auto pdim = om_d.padded_dims()[d];
+        const auto cblock = oblocks[d];
+        // do not allow excess pdim other than required for rounding-up of dim.
+        if (utils::rnd_up(dim, cblock) != pdim) return unimplemented;
+    }
+
+    utils::array_set(i_tails, 0, im_d.ndims());
+    utils::array_set(o_tails, 0, om_d.ndims());
+    utils::array_set(i_paddings, 0, im_d.ndims());
+    utils::array_set(o_paddings, 0, om_d.ndims());
 
     for (int d = 0; d < im_d.ndims(); ++d) {
-        const int ip_tmp_dim = im_d.padded_dims()[d];
-        const int op_tmp_dim = om_d.padded_dims()[d];
-        const int ip_tmp_tail = ip_tmp_dim % oblocks[d];
-        const int op_tmp_tail = op_tmp_dim % iblocks[d];
-
-        const bool pdim_consistent = ip_tmp_dim == op_tmp_dim
-                && ip_tmp_tail == 0 && op_tmp_tail == 0;
-        const bool pdim_tail = ip_tmp_tail > 0
-                && (ip_tmp_dim + oblocks[d] - ip_tmp_tail) == op_tmp_dim
-                && op_tmp_tail == 0 && ip_tail == 0;
-        if (!pdim_consistent && !pdim_tail) return status::unimplemented;
-        if (pdim_tail) {
-            blk_idx = d;
-            ip_tail = ip_tmp_tail;
-            op_tail = op_tmp_tail;
-            iblk_w_tail = iblocks[d];
-            oblk_w_tail = oblocks[d];
-            ip_padding[d] = oblocks[d] - ip_tmp_tail;
-            op_padding[d] = iblocks[d] - op_tmp_tail;
+        const dim_t i_dim = im_d.dims()[d];
+        const dim_t o_dim = om_d.dims()[d];
+        const dim_t i_tail = i_dim % iblocks[d];
+        const dim_t o_tail = o_dim % oblocks[d];
+
+        if (o_tail > 0) {
+            is_tail_present = true;
+            o_tails[d] = o_tail;
+            o_paddings[d] = oblocks[d] - o_tail;
+        }
+
+        if (i_tail > 0) {
+            is_tail_present = true;
+            i_tails[d] = i_tail;
+            i_paddings[d] = iblocks[d] - i_tail;
         }
     }
-    CHECK(compute_blk_and_tail(omd, blk_idx, oblk_w_tail, ip_tail));
 
+    // To compute input layout description we need to pass output paddings
+    // which will be used to compute input dims rounded up to multiple of
+    // output dims. Analogous applies to output layout description.
+    // This is demanded by the algorithm of nodes creation.
+    // Example:
+    // input:
+    //  format: abc
+    //  size: 77, 15, 3
+    //  o_padding: 3, 17, 0
+    //  returns ild: 80, 32, 3
+    // output:
+    //  format: ABc16b16a2b
+    //  size: 77, 15, 3
+    //  i_padding: 0, 0, 0
+    //  returns old: 5, 16, 1, 16, 2, 3
     layout_desc_t ild, old;
-    status_t status
-            = cvt_mem_desc_to_layout_desc(imd, ild, iblocks, ip_padding);
-    if (status != success) return status;
-    status = cvt_mem_desc_to_layout_desc(omd, old, oblocks, op_padding);
-    if (status != success) return status;
+    CHECK(cvt_mem_desc_to_layout_desc(imd, ild, iblocks, o_paddings, i_tails));
+    CHECK(cvt_mem_desc_to_layout_desc(omd, old, oblocks, i_paddings, o_tails));
 
     p.itype = ild.dt;
     p.otype = old.dt;
-    p.ip_tail = ip_tail;
-    p.op_tail = op_tail;
-    p.iblock = iblk_w_tail;
-    p.oblock = oblk_w_tail;
-
+    p.is_tail_present = is_tail_present;
+    p.req_src_zp = !attr->zero_points_.has_default_values(DNNL_ARG_SRC);
+    p.req_dst_zp = !attr->zero_points_.has_default_values(DNNL_ARG_DST);
     p.scale_type = attr->output_scales_.has_default_values()
             ? scale_type_t::NONE
             : (attr->output_scales_.mask_ == 0 ? scale_type_t::COMMON
                                                : scale_type_t::MANY);
+    p.scale_adjust = (om_d.extra().flags & memory_extra_flags::scale_adjust)
+            ? om_d.extra().scale_adjust
+            : 1.f;
+    p.req_s8s8_comp
+            = om_d.extra().flags & memory_extra_flags::compensation_conv_s8s8;
+    p.req_asymmetric_comp = om_d.extra().flags
+            & memory_extra_flags::compensation_conv_asymmetric_src;
+
+    const bool with_groups = is_with_groups(omd);
+
+    auto mask_ok = [&](bool check, int mask) {
+        return IMPLICATION(check, mask == (with_groups ? 0x3 : 0x1));
+    };
+
+    if (!mask_ok(p.req_s8s8_comp, om_d.extra().compensation_mask)
+            || !mask_ok(p.req_asymmetric_comp,
+                    om_d.extra().asymm_compensation_mask))
+        return status::unimplemented;
 
-    ptrdiff_t ss[max_ndims] = {0};
+    ptrdiff_t ss[max_ndims] = {0}; // scales strides
     if (p.scale_type == scale_type_t::MANY) {
-        ptrdiff_t last_ss = 1;
+        const int mask = attr->output_scales_.mask_;
+        ptrdiff_t dense_stride = 1;
+        ptrdiff_t last_stride = 1;
         for (int d = old.ndims - 1; d >= 0; --d) {
             assert((d == 0 || old.id[d - 1] <= old.id[d])
                     && "logical dimensions should be in ascending order");
-            if (attr->output_scales_.mask_ & (1 << old.id[d])) {
-                ss[d] = last_ss;
-                last_ss *= old.dims[d];
+            if (mask & (1 << old.id[d])) {
+                if ((d + 1) < old.ndims && old.id[d + 1] != old.id[d]
+                        && (mask & (1 << old.id[d + 1]))) {
+                    dense_stride = dense_stride * imd.dims[old.id[d + 1]];
+                    last_stride = dense_stride;
+                }
+                ss[d] = last_stride;
+                last_stride *= old.dims[d];
             }
         }
     }
 
+    const auto compensation_needed = p.req_s8s8_comp || p.req_asymmetric_comp;
+    if (compensation_needed) {
+        p.compensation_mask = p.req_s8s8_comp
+                ? om_d.extra().compensation_mask
+                : (p.req_asymmetric_comp ? om_d.extra().asymm_compensation_mask
+                                         : tr::prb_t::invalid_comp_mask);
+
+        if (p.compensation_mask == tr::prb_t::asymmetric_comp_mask)
+            return unimplemented;
+
+        assert(p.compensation_mask == tr::prb_t::standard_comp_mask
+                || p.compensation_mask == tr::prb_t::comp_mask_with_groups);
+    }
+
     int ndims = 0;
 
     int i_pos = 0; /* state for input  -- current dimension */
@@ -254,6 +343,10 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd,
 
         if (ild.dims[i_pos] == old.dims[o_pos]) {
             p.nodes[ndims].n = ild.dims[i_pos];
+            p.nodes[ndims].dim_id = old.id[o_pos];
+            p.nodes[ndims].tail_size = old.tails[o_pos];
+            p.nodes[ndims].is_zero_pad_needed
+                    = old.is_blk[o_pos] && old.tails[o_pos] > 0;
             p.nodes[ndims].is = ild.strides[i_pos];
             p.nodes[ndims].os = old.strides[o_pos];
             p.nodes[ndims].ss = ss[o_pos];
@@ -261,19 +354,45 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd,
             ++i_pos;
             ++o_pos;
         } else if (ild.dims[i_pos] < old.dims[o_pos]) {
-            assert(old.dims[o_pos] % ild.dims[i_pos] == 0);
-            int factor = old.dims[o_pos] / ild.dims[i_pos];
+            // old must be divisible by ild or we will not be
+            // able to create valid nodes. The problem appears
+            // when stag=Acdb48a and dtag=Acdb32a for example.
+            if (ild.dims[i_pos] == 0 || old.dims[o_pos] % ild.dims[i_pos] != 0)
+                return status::unimplemented;
+
+            dim_t factor = old.dims[o_pos] / ild.dims[i_pos];
+
+            const size_t tail_of_upper_dim
+                    = utils::div_up(old.tails[o_pos], factor) == ild.dims[i_pos]
+                    ? 0
+                    : utils::div_up(old.tails[o_pos], factor);
+            const size_t tail_of_lower_dim = old.tails[o_pos] % factor;
+
             p.nodes[ndims].n = ild.dims[i_pos];
+            p.nodes[ndims].dim_id = old.id[o_pos];
+            p.nodes[ndims].tail_size = tail_of_upper_dim;
+            p.nodes[ndims].is_zero_pad_needed
+                    = old.is_blk[o_pos] && tail_of_upper_dim > 0;
             p.nodes[ndims].is = ild.strides[i_pos];
             p.nodes[ndims].os = old.strides[o_pos] * factor;
             p.nodes[ndims].ss = ss[o_pos] * factor;
             ++ndims;
             ++i_pos;
             old.dims[o_pos] = factor;
+            old.tails[o_pos] = tail_of_lower_dim;
         } else if (ild.dims[i_pos] > old.dims[o_pos]) {
-            assert(ild.dims[i_pos] % old.dims[o_pos] == 0);
-            int factor = ild.dims[i_pos] / old.dims[o_pos];
+            // ild must be divisible by old or we will not be
+            // able to create valid nodes. The problem appears
+            // when stag=Acdb32a and dtag=Acdb48a for example.
+            if (old.dims[o_pos] == 0 || ild.dims[i_pos] % old.dims[o_pos] != 0)
+                return status::unimplemented;
+
+            dim_t factor = ild.dims[i_pos] / old.dims[o_pos];
             p.nodes[ndims].n = old.dims[o_pos];
+            p.nodes[ndims].dim_id = old.id[o_pos];
+            p.nodes[ndims].tail_size = old.tails[o_pos];
+            p.nodes[ndims].is_zero_pad_needed
+                    = old.is_blk[o_pos] && old.tails[o_pos] > 0;
             p.nodes[ndims].is = ild.strides[i_pos] * factor;
             p.nodes[ndims].os = old.strides[o_pos];
             p.nodes[ndims].ss = ss[o_pos];
@@ -282,12 +401,9 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd,
             ild.dims[i_pos] = factor;
         }
     }
-    int blk_chunk_idx = ndims;
-    CHECK(compute_chunk_idx(p, imd, omd, blk_idx, blk_chunk_idx));
 
     p.ndims = ndims;
     p.full_ndims = ndims;
-    p.blk_chunk_idx = blk_chunk_idx;
 
     p.ioff = memory_desc_wrapper(imd).offset0();
     p.ooff = memory_desc_wrapper(omd).offset0();
@@ -295,6 +411,28 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd,
     const int sum_idx = attr->post_ops_.find(primitive_kind::sum);
     p.beta = sum_idx == -1 ? 0.f : attr->post_ops_.entry_[sum_idx].sum.scale;
 
+    DEBUG({
+        printf("init : ");
+        prb_dump(p);
+    }); // dump blocks are compiled only when TR_DEBUG is defined
+    // Sort the prb array in increasing sizes of the output stride
+    prb_normalize(p);
+    DEBUG({
+        printf("norm : ");
+        prb_dump(p);
+    });
+
+    // compensation strides require a normalized prb (prb_normalize above)
+    prb_set_compensation_strides(p);
+
+    /* Combine the variables, which appear together on both
+     * sides of the reorder */
+    prb_simplify(p);
+    DEBUG({
+        printf("smpl : ");
+        prb_dump(p);
+    });
+
     return success;
 }
 
@@ -307,28 +445,23 @@ void prb_normalize(prb_t &p) {
                             && p.nodes[j].n < p.nodes[min_pos].n);
             if (new_min) min_pos = j;
         }
-        if (min_pos != d) {
-            nstl::swap(p.nodes[d], p.nodes[min_pos]);
-            if (p.blk_chunk_idx == min_pos || p.blk_chunk_idx == d)
-                p.blk_chunk_idx = p.blk_chunk_idx == min_pos ? d : min_pos;
-        }
+        if (min_pos != d) { nstl::swap(p.nodes[d], p.nodes[min_pos]); }
     }
 }
 
-status_t prb_check_blk(prb_t &p, const memory_desc_t &md_) {
-    const auto md = memory_desc_wrapper(md_);
-    const auto &bd = md.blocking_desc();
-    if (p.ip_tail == 0) return status::success;
-
-    // Check if the inner blocks and p.nodes[blk].n in the firsti nblks
-    // is equivalent in reverse order when has tail in block layout.
-    const int nblk = bd.inner_nblks;
-    for (int iblk = 0; iblk < nblk; ++iblk) {
-        if (bd.inner_blks[nblk - iblk - 1]
-                != static_cast<ptrdiff_t>(p.nodes[iblk].n))
-            return status::unimplemented;
+void prb_node_dependency(prb_t &prb) { // sets parent_node_id of every node
+    for (int i = 0; i < prb.ndims; i++) {
+        tr::node_t &node = prb.nodes[i];
+        node.parent_node_id = node_t::empty_field; // reset before searching
+        for (int j = i + 1; j < prb.ndims; j++) { // only later nodes can be parents
+            const tr::node_t &potential_parent_node = prb.nodes[j];
+            if (!potential_parent_node.is_dim_id_empty()
+                    && potential_parent_node.dim_id == node.dim_id) {
+                node.parent_node_id = j; // nearest later node with the same dim_id
+                break;
+            }
+        }
     }
-    return status::success;
 }
 
 void prb_simplify(prb_t &p) {
@@ -338,16 +471,25 @@ void prb_simplify(prb_t &p) {
 #pragma GCC diagnostic push
 #pragma GCC diagnostic ignored "-Warray-bounds"
 #endif
+
+    const auto skip_dim_combining = [&p](const int node_id) -> bool {
+        return (p.is_tail_in_one_of_child_nodes(node_id)
+                       && p.nodes[node_id].n > 1)
+                || p.nodes[node_id].tail_size > 0;
+    };
+
+    if (p.is_tail_present) prb_node_dependency(p);
+
     for (int d = 0; d < p.ndims - 1; ++d) {
         auto &this_node = p.nodes[d + 0];
         auto &next_node = p.nodes[d + 1];
-        const bool skip_blk_idx = (p.ip_tail > 0 || p.op_tail > 0)
-                && (p.blk_chunk_idx == d || p.blk_chunk_idx == d + 1);
+        const bool skip_dims_combining
+                = skip_dim_combining(d) || skip_dim_combining(d + 1);
         const bool fold = false
                 || (next_node.n == static_cast<size_t>(1)
-                        && !skip_blk_idx) // trivial case, just drop next node
+                        && !skip_dims_combining) // trivial case, just drop next node
                 || (true // or real folding if possible
-                        && !skip_blk_idx
+                        && !skip_dims_combining
                         && next_node.is
                                 == static_cast<ptrdiff_t>(
                                         this_node.n * this_node.is)
@@ -356,15 +498,20 @@ void prb_simplify(prb_t &p) {
                                         this_node.n * this_node.os)
                         && next_node.ss
                                 == static_cast<ptrdiff_t>(
-                                        this_node.n * this_node.ss));
+                                        this_node.n * this_node.ss)
+                        && next_node.cs
+                                == static_cast<ptrdiff_t>(
+                                        this_node.n * this_node.cs));
         if (fold) {
             this_node.n *= next_node.n;
+            this_node.dim_id = node_t::empty_field;
+            this_node.is_zero_pad_needed = false;
             for (int j = d + 2; j < p.ndims; ++j)
                 p.nodes[j - 1] = p.nodes[j];
-            if (d < p.blk_chunk_idx) --p.blk_chunk_idx;
             --p.ndims;
             --p.full_ndims;
             --d; // make another try
+            if (p.is_tail_present) prb_node_dependency(p);
         }
     }
 #if defined(__GNUC__) && __GNUC__ >= 4
@@ -372,24 +519,42 @@ void prb_simplify(prb_t &p) {
 #endif
 }
 
-void prb_node_split(prb_t &p, int dim, size_t n1) {
+// Split nodes[dim] into a lower node (size new_node_size) and an upper node.
+void prb_node_split(prb_t &p, int dim, size_t new_node_size) {
     assert(dim < p.ndims);
     assert(p.ndims < max_ndims);
-    assert(p.nodes[dim].n % n1 == 0);
+    assert(p.nodes[dim].n % new_node_size == 0);
 
     p.ndims += 1;
     p.full_ndims += 1;
-    if (dim < p.blk_chunk_idx) p.blk_chunk_idx += 1;
 
-    for (int d = p.ndims; d > dim + 1; --d)
+    for (int d = p.ndims - 1; d > dim + 1; --d) // last live index: ndims - 1
         p.nodes[d] = p.nodes[d - 1];
 
-    p.nodes[dim + 1].n = p.nodes[dim].n / n1;
-    p.nodes[dim + 1].is = p.nodes[dim].is * n1;
-    p.nodes[dim + 1].os = p.nodes[dim].os * n1;
-    p.nodes[dim + 1].ss = p.nodes[dim].ss * n1;
-
-    p.nodes[dim].n = n1;
+    const size_t upper_node_size = p.nodes[dim].n / new_node_size;
+    const size_t lower_node_size = new_node_size;
+    p.nodes[dim + 1].n = upper_node_size;
+    p.nodes[dim].n = lower_node_size;
+    // Upper tail = tail chunks of lower size (rounded up); lower = remainder.
+    const bool is_tail = p.nodes[dim].tail_size > 0;
+    const size_t tail_chunks
+            = utils::div_up(p.nodes[dim].tail_size, lower_node_size);
+    const size_t upper_node_tail
+            = tail_chunks == upper_node_size ? 0 : tail_chunks;
+    const size_t lower_node_tail = p.nodes[dim].tail_size % lower_node_size;
+    p.nodes[dim].tail_size = is_tail ? lower_node_tail : 0;
+    p.nodes[dim + 1].tail_size = is_tail ? upper_node_tail : 0;
+    // Only a node that still has a tail keeps the zero-padding flag.
+    p.nodes[dim + 1].is_zero_pad_needed
+            = p.nodes[dim].is_zero_pad_needed && p.nodes[dim + 1].tail_size > 0;
+    p.nodes[dim].is_zero_pad_needed
+            = p.nodes[dim].is_zero_pad_needed && p.nodes[dim].tail_size > 0;
+    // The upper node shares dim_id; its strides span one full lower node.
+    p.nodes[dim + 1].dim_id = p.nodes[dim].dim_id;
+    p.nodes[dim + 1].is = p.nodes[dim].is * lower_node_size;
+    p.nodes[dim + 1].os = p.nodes[dim].os * lower_node_size;
+    p.nodes[dim + 1].ss = p.nodes[dim].ss * lower_node_size;
+    p.nodes[dim + 1].cs = p.nodes[dim].cs * lower_node_size;
 }
 
 void prb_node_swap(prb_t &p, int d0, int d1) {
@@ -425,8 +590,11 @@ void prb_dump(const prb_t &p) {
     printf("@@@ type:%s:%s ndims:%d ", dnnl_dt2str(p.itype),
             dnnl_dt2str(p.otype), p.ndims);
     for (int d = 0; d < p.ndims; ++d)
-        printf("[%zu:%td:%td:%td]", p.nodes[d].n, p.nodes[d].is, p.nodes[d].os,
-                p.nodes[d].ss);
+        printf("[%zu:%zu:%d:%d:%s:%td:%td:%td:%td]", p.nodes[d].n,
+                p.nodes[d].tail_size, p.nodes[d].dim_id,
+                p.nodes[d].parent_node_id,
+                p.nodes[d].is_zero_pad_needed ? "true" : "false", p.nodes[d].is,
+                p.nodes[d].os, p.nodes[d].ss, p.nodes[d].cs);
     printf(" off:%zu:%zu\n", p.ioff, p.ooff);
 }
 
diff --git a/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp
index f51e3c22414..fdefec8a049 100644
--- a/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp
+++ b/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp
@@ -1,5 +1,6 @@
 /*******************************************************************************
 * Copyright 2020-2022 Intel Corporation
+* Copyright 2022 FUJITSU LIMITED
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -32,6 +33,7 @@ const impl_list_map_t &regular_f32_f32_impl_list_map() {
             DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t))
             DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
 
+            DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t))
             DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
             REG_SR(f32, any, f32, any, fmt_order::any, spec::reference)
 
@@ -44,6 +46,7 @@ const impl_list_map_t &regular_f32_f32_impl_list_map() {
             DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t))
             DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
 
+            DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t))
             DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
             DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, nCw16c))
             DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, nCw8c))
@@ -75,6 +78,7 @@ const impl_list_map_t &regular_f32_f32_impl_list_map() {
             DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t))
             DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
 
+            DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t))
             DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
 
             DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, nChw16c))
@@ -123,6 +127,7 @@ const impl_list_map_t &regular_f32_f32_impl_list_map() {
             DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t))
             DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
 
+            DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t))
             DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
 
             DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, nCdhw16c))
@@ -171,6 +176,7 @@ const impl_list_map_t &regular_f32_f32_impl_list_map() {
             DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t))
             DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
 
+            DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t))
             DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
 
 
diff --git a/src/cpu/reorder/cpu_reorder_regular_f32_s32.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_s32.cpp
index fadbee0ecf8..b1881df80e0 100644
--- a/src/cpu/reorder/cpu_reorder_regular_f32_s32.cpp
+++ b/src/cpu/reorder/cpu_reorder_regular_f32_s32.cpp
@@ -1,5 +1,6 @@
 /*******************************************************************************
 * Copyright 2020-2022 Intel Corporation
+* Copyright 2022 FUJITSU LIMITED
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -31,6 +32,7 @@ const impl_list_map_t &regular_f32_s32_impl_list_map() {
             DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t))
             DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
 
+            DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t))
             DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
             DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, s32, nChw16c))
             REG_SR(f32, any, s32, any, fmt_order::any, spec::reference)
diff --git a/src/cpu/reorder/cpu_reorder_regular_f32_s8.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_s8.cpp
index b83d47b2d6f..6bd305c7b41 100644
--- a/src/cpu/reorder/cpu_reorder_regular_f32_s8.cpp
+++ b/src/cpu/reorder/cpu_reorder_regular_f32_s8.cpp
@@ -1,5 +1,6 @@
 /*******************************************************************************
 * Copyright 2020-2022 Intel Corporation
+* Copyright 2022 FUJITSU LIMITED
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -35,6 +36,7 @@ const impl_list_map_t &regular_f32_s8_impl_list_map() {
             DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t))
             DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
 
+            DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t))
             DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
 
             DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, s8, nChw16c))
diff --git a/src/cpu/reorder/cpu_reorder_regular_f32_u8.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_u8.cpp
index 4bae84307e6..d306c3abeb8 100644
--- a/src/cpu/reorder/cpu_reorder_regular_f32_u8.cpp
+++ b/src/cpu/reorder/cpu_reorder_regular_f32_u8.cpp
@@ -1,5 +1,6 @@
 /*******************************************************************************
 * Copyright 2020-2022 Intel Corporation
+* Copyright 2022 FUJITSU LIMITED
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -33,6 +34,7 @@ const impl_list_map_t &regular_f32_u8_impl_list_map() {
             DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t))
             DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
 
+            DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t))
             DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
             DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, u8, nChw16c))
             REG_SR(f32, any, u8, any, fmt_order::any, spec::reference)
diff --git a/src/cpu/reorder/cpu_reorder_regular_s32.cpp b/src/cpu/reorder/cpu_reorder_regular_s32.cpp
index 54d65661791..a8197402b0a 100644
--- a/src/cpu/reorder/cpu_reorder_regular_s32.cpp
+++ b/src/cpu/reorder/cpu_reorder_regular_s32.cpp
@@ -1,5 +1,6 @@
 /*******************************************************************************
 * Copyright 2020-2022 Intel Corporation
+* Copyright 2022 FUJITSU LIMITED
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -34,6 +35,7 @@ const impl_list_map_t &regular_s32_impl_list_map() {
             DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t))
             DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
 
+            DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t))
             DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
 
             DNNL_NON_X64_ONLY(REG_SR_BIDIR(s32, any, f32, nChw16c))
diff --git a/src/cpu/reorder/cpu_reorder_regular_s8.cpp b/src/cpu/reorder/cpu_reorder_regular_s8.cpp
index f57d01e2009..ce18dc5caf1 100644
--- a/src/cpu/reorder/cpu_reorder_regular_s8.cpp
+++ b/src/cpu/reorder/cpu_reorder_regular_s8.cpp
@@ -1,5 +1,6 @@
 /*******************************************************************************
 * Copyright 2020-2022 Intel Corporation
+* Copyright 2022 FUJITSU LIMITED
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -41,6 +42,7 @@ const impl_list_map_t &regular_s8_impl_list_map() {
             DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t))
             DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
 
+            DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t))
             DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
 
             DNNL_NON_X64_ONLY(REG_SR_BIDIR(s8, any, f32, nChw16c))
diff --git a/src/cpu/reorder/cpu_reorder_regular_u8.cpp b/src/cpu/reorder/cpu_reorder_regular_u8.cpp
index 73d731c3b15..87a58872262 100644
--- a/src/cpu/reorder/cpu_reorder_regular_u8.cpp
+++ b/src/cpu/reorder/cpu_reorder_regular_u8.cpp
@@ -1,5 +1,6 @@
 /*******************************************************************************
 * Copyright 2020-2022 Intel Corporation
+* Copyright 2022 FUJITSU LIMITED
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -35,6 +36,7 @@ const impl_list_map_t &regular_u8_impl_list_map() {
             DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t))
             DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
 
+            DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t))
             DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
 
             DNNL_NON_X64_ONLY(REG_SR_BIDIR(u8, any, f32, nChw16c))
